Compare commits

...

1 Commits

Author SHA1 Message Date
Khalil Meftah 17e217d175 refactor(policies): clean MolmoAct2 to follow EO1/TOPReward patterns
Align the MolmoAct2 implementation with lerobot codebase conventions:

- Rename hf_model/ to molmoact2_hf_model/
- Slim config: move all I/O and runtime logic to modeling
- Remove blanket  from 8 vendored files, fix 66 lint issues
- Deduplicate _hf_token() and _resolve_checkpoint_location()
- Make huggingface_hub imports lazy
- Remove custom MolmoAct2CosineDecayWithWarmupSchedulerConfig, use base class
- Extract 13 static/classmethods from MolmoAct2Policy to free functions
- Replace print() with logger in vendored action_tokenizer
- Add module docstrings, class docstring, and key method docstrings
- Add module-level loggers to modeling and processor
- Fix docs: pip to uv install, deduplicate README symlink
- Remove shebangs from all files
2026-06-05 16:31:03 +02:00
15 changed files with 611 additions and 694 deletions
+3 -3
View File
@@ -17,7 +17,7 @@ the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
Install LeRobot with the MolmoAct2 optional dependencies:
```bash
pip install -e ".[molmoact2]"
uv sync --locked --extra molmoact2
```
To run the models in this repository, you need an NVIDIA GPU. The measurements
@@ -46,8 +46,8 @@ The repo has been tested with Ubuntu 22.04.
To use MolmoAct2 in a LeRobot training config, set:
```python
policy.type=molmoact2
```bash
--policy.type=molmoact2
```
## Training
+1 -1
View File
@@ -1 +1 @@
../../../../docs/source/policy_molmoact2_README.md
../../../../docs/source/molmoact2.mdx
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,16 +14,9 @@
from __future__ import annotations
import json
import math
import os
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import (
AdamWConfig,
@@ -37,146 +28,6 @@ from lerobot.utils.constants import ACTION, OBS_STATE
from ..rtc.configuration_rtc import RTCConfig
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
MOLMOACT2_TASK_TOKEN_BUDGET = 32
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
@dataclass
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
training steps"; build() is the first point where num_training_steps is known.
"""
num_decay_steps: int | None
def build(self, optimizer, num_training_steps: int):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.peak_lr,
decay_lr=self.decay_lr,
num_warmup_steps=self.num_warmup_steps,
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
).build(optimizer, num_training_steps=num_training_steps)
def _round_up(value: int, multiple: int) -> int:
return int(math.ceil(value / multiple) * multiple)
def infer_molmoact2_max_sequence_length(
*,
num_images: int,
state_dim: int,
action_dim: int,
action_horizon: int,
include_discrete_action: bool,
) -> int:
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
if num_images < 1:
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim < 0:
state_dim = 0
if action_dim < 1:
action_dim = 1
if action_horizon < 1:
action_horizon = 1
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
prompt_tokens = (
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
+ MOLMOACT2_TASK_TOKEN_BUDGET
+ state_dim
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
)
action_tokens = 0
if include_discrete_action:
action_tokens_per_step = max(
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
)
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
return _round_up(
image_tokens + prompt_tokens + action_tokens,
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
)
@PreTrainedConfig.register_subclass("molmoact2")
@dataclass
@@ -255,7 +106,7 @@ class MolmoAct2Config(PreTrainedConfig):
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 200
scheduler_decay_steps: int | None = None
scheduler_decay_steps: int = 100_000
scheduler_decay_lr: float = 1e-6
normalization_mapping: dict[str, NormalizationMode] = field(
@@ -333,41 +184,6 @@ class MolmoAct2Config(PreTrainedConfig):
if self.max_sequence_length is not None and self.max_sequence_length < 1:
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
def inferred_max_sequence_length(
self,
*,
num_images: int | None = None,
state_dim: int | None = None,
action_dim: int | None = None,
action_horizon: int | None = None,
include_discrete_action: bool | None = None,
) -> int:
if self.max_sequence_length is not None:
return int(self.max_sequence_length)
if num_images is None:
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim is None:
state_feature = self.robot_state_feature
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
if action_dim is None:
action_feature = self.action_feature
action_dim = (
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
)
if action_horizon is None:
action_horizon = self.chunk_size
if include_discrete_action is None:
include_discrete_action = self.action_mode in {"discrete", "both"}
return infer_molmoact2_max_sequence_length(
num_images=int(num_images),
state_dim=int(state_dim),
action_dim=int(action_dim),
action_horizon=int(action_horizon),
include_discrete_action=bool(include_discrete_action),
)
@property
def observation_delta_indices(self) -> None:
return None
@@ -390,7 +206,7 @@ class MolmoAct2Config(PreTrainedConfig):
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
@@ -426,94 +242,3 @@ class MolmoAct2Config(PreTrainedConfig):
shape=(self.expected_max_action_dim,),
)
self.output_features[ACTION] = action_feature
def apply_norm_tag_metadata(self) -> None:
if not str(self.norm_tag or "").strip():
return
metadata = _load_hf_norm_metadata_for_tag(
self.checkpoint_path,
revision=self.checkpoint_revision,
force_download=bool(self.checkpoint_force_download),
norm_tag=self.norm_tag,
)
if metadata.get("action_horizon") is not None:
self.chunk_size = int(metadata["action_horizon"])
if metadata.get("n_action_steps") is not None:
self.n_action_steps = int(metadata["n_action_steps"])
if not self.setup_type and metadata.get("setup_type") is not None:
self.setup_type = str(metadata["setup_type"])
if not self.control_mode and metadata.get("control_mode") is not None:
self.control_mode = str(metadata["control_mode"])
def saved_policy_action_mode(self) -> str | None:
pretrained_path = getattr(self, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
return saved_policy_action_mode or self.action_mode
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
requested_mode = self.inference_action_mode
if requested_mode is None:
return
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
"continuous inference."
)
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
)
def validate_checkpoint_action_mode(
self,
checkpoint_action_mode: str,
*,
has_action_expert: bool,
) -> None:
if self.action_mode == "both" and checkpoint_action_mode != "both":
raise ValueError(
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
)
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
raise ValueError(
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
f"got {checkpoint_action_mode!r}."
)
if self.action_mode in {"continuous", "both"} and not has_action_expert:
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
def resolve_inference_action_mode(
self,
requested_mode: str | None,
saved_policy_action_mode: str | None = None,
) -> str:
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode is None:
requested_mode = self.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,9 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""MolmoAct2 policy for LeRobot.
MolmoAct2 is a VLM-based robotics policy from Allen AI that combines a
Molmo vision-language backbone with a per-layer flow-matching action expert
for continuous action generation, plus an optional discrete action token
head. This module wraps the vendored HF model implementation
(``molmoact2_hf_model/``) into the LeRobot ``PreTrainedPolicy`` interface.
Paper: https://allenai.org/blog/molmoact2
Code: https://github.com/allenai/molmoact2
"""
from __future__ import annotations
import json
import logging
import os
import types
from collections import deque
@@ -35,13 +46,58 @@ from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
from ..rtc.modeling_rtc import RTCProcessor
from .configuration_molmoact2 import MolmoAct2Config, _hf_token, _resolve_checkpoint_location
from .configuration_molmoact2 import MolmoAct2Config
logger = logging.getLogger(__name__)
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
"""Resolve a checkpoint path to a local directory, downloading from Hub if needed."""
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
from pathlib import Path
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
from huggingface_hub import snapshot_download
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _torch_dtype(dtype: str) -> torch.dtype:
"""Convert a dtype name string to a torch.dtype."""
if dtype == "float32":
return torch.float32
if dtype == "bfloat16":
return torch.bfloat16
if dtype == "float16":
return torch.float16
raise ValueError(f"Unsupported dtype: {dtype}")
if TYPE_CHECKING or _transformers_available:
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from .hf_model.configuration_molmoact2 import MolmoAct2Config as HFMolmoAct2Config
from .hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
from .molmoact2_hf_model.configuration_molmoact2 import MolmoAct2Config as HFMolmoAct2Config
from .molmoact2_hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
else:
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
@@ -49,7 +105,7 @@ else:
MolmoAct2ForConditionalGeneration = None
if TYPE_CHECKING or (_transformers_available and _scipy_available):
from .hf_model.action_tokenizer import UniversalActionProcessor
from .molmoact2_hf_model.action_tokenizer import UniversalActionProcessor
else:
UniversalActionProcessor = None
@@ -70,6 +126,156 @@ _MODEL_INPUT_KEYS = {
}
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
"""Read per-tag metadata from the checkpoint's ``norm_stats.json``."""
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
from contextlib import suppress
from pathlib import Path
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
def _apply_norm_tag_metadata(config: MolmoAct2Config) -> None:
"""Populate config fields from the checkpoint's norm-tag metadata."""
if not str(config.norm_tag or "").strip():
return
metadata = _load_hf_norm_metadata_for_tag(
config.checkpoint_path,
revision=config.checkpoint_revision,
force_download=bool(config.checkpoint_force_download),
norm_tag=config.norm_tag,
)
if metadata.get("action_horizon") is not None:
config.chunk_size = int(metadata["action_horizon"])
if metadata.get("n_action_steps") is not None:
config.n_action_steps = int(metadata["n_action_steps"])
if not config.setup_type and metadata.get("setup_type") is not None:
config.setup_type = str(metadata["setup_type"])
if not config.control_mode and metadata.get("control_mode") is not None:
config.control_mode = str(metadata["control_mode"])
def _saved_policy_action_mode(config: MolmoAct2Config) -> str | None:
"""Read the action mode from a LeRobot-saved checkpoint's ``config.json``."""
from pathlib import Path
pretrained_path = getattr(config, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def _training_action_mode(config: MolmoAct2Config, saved_policy_action_mode: str | None = None) -> str:
return saved_policy_action_mode or config.action_mode
def _validate_inference_action_mode(
config: MolmoAct2Config, saved_policy_action_mode: str | None = None
) -> None:
"""Check that the requested inference mode is compatible with the training mode."""
requested_mode = config.inference_action_mode
if requested_mode is None:
return
training_mode = _training_action_mode(config, saved_policy_action_mode)
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
"continuous inference."
)
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
)
def _validate_checkpoint_action_mode(
config: MolmoAct2Config,
checkpoint_action_mode: str,
*,
has_action_expert: bool,
) -> None:
"""Check that the checkpoint's action mode is compatible with the config."""
if config.action_mode == "both" and checkpoint_action_mode != "both":
raise ValueError(
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
)
if config.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
raise ValueError(
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
f"got {checkpoint_action_mode!r}."
)
if config.action_mode in {"continuous", "both"} and not has_action_expert:
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
def _resolve_inference_action_mode(
config: MolmoAct2Config,
requested_mode: str | None,
saved_policy_action_mode: str | None = None,
) -> str:
"""Resolve the final inference action mode, validating compatibility."""
training_mode = _training_action_mode(config, saved_policy_action_mode)
if requested_mode is None:
requested_mode = config.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode
def _strict_load_safetensors_weights(model: torch.nn.Module, checkpoint_location: str) -> None:
index_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_INDEX_NAME)
single_file_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_NAME)
@@ -103,16 +309,6 @@ def _strict_load_safetensors_weights(model: torch.nn.Module, checkpoint_location
)
def _torch_dtype(dtype: str) -> torch.dtype:
if dtype == "float32":
return torch.float32
if dtype == "bfloat16":
return torch.bfloat16
if dtype == "float16":
return torch.float16
raise ValueError(f"Unsupported dtype: {dtype}")
def _sample_beta_timesteps(
*,
batch_size: int,
@@ -136,7 +332,180 @@ def _sample_beta_timesteps(
return time_offset + scale * samples
def _mask_discrete_action_spans(
*,
input_ids: Tensor,
mask: Tensor,
start_token_id: int | None,
end_token_id: int | None,
) -> Tensor:
if start_token_id is None or end_token_id is None:
return mask
mask = mask.clone()
for batch_idx in range(input_ids.shape[0]):
row = input_ids[batch_idx]
starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist()
ends = (row == int(end_token_id)).nonzero(as_tuple=False).flatten().tolist()
end_ptr = 0
for start in starts:
while end_ptr < len(ends) and ends[end_ptr] < start:
end_ptr += 1
if end_ptr >= len(ends):
mask[batch_idx, start:] = False
break
end = int(ends[end_ptr])
mask[batch_idx, start : end + 1] = False
end_ptr += 1
return mask
def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]:
attention_mask = model_inputs.get("attention_mask")
if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()):
model_inputs = dict(model_inputs)
model_inputs.pop("attention_mask", None)
return model_inputs
def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None:
if mask is None:
return None
return (
mask.unsqueeze(1)
.expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1)))
.reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:])
)
def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None:
if action_dim_is_pad is None:
return None
mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool)
if mask.ndim == 1:
mask = mask.unsqueeze(0)
if mask.shape[-1] != target.shape[-1]:
raise ValueError(
f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}."
)
if mask.shape[0] == 1 and target.shape[0] != 1:
mask = mask.expand(target.shape[0], -1)
if mask.shape[0] != target.shape[0]:
raise ValueError(
f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}."
)
while mask.ndim < target.ndim:
mask = mask.unsqueeze(1)
return mask
def _mask_action_dim_tensor(tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
if action_dim_is_pad is None:
return tensor
valid_mask = _action_dim_valid_mask(tensor, action_dim_is_pad)
if valid_mask is None:
return tensor
return tensor.masked_fill(~valid_mask, 0)
def _apply_action_dim_padding_mask(loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
valid_mask = _action_dim_valid_mask(loss, action_dim_is_pad)
if valid_mask is None:
return loss
valid = valid_mask.to(dtype=loss.dtype)
denom = valid.sum(dim=-1).clamp_min(1.0)
return (loss * valid).sum(dim=-1) / denom
def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor:
if action_horizon_is_pad is None:
return loss
valid_action = (
(~action_horizon_is_pad.to(device=loss.device, dtype=torch.bool)).unsqueeze(1).unsqueeze(-1)
)
return loss * valid_action
def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int:
seed = 0
for idx in range(batch_size):
seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1)
return seed
def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None:
task = batch.get("task")
if task is None:
task = batch.get("observation.language")
if task is None:
return None
if isinstance(task, str):
return (task,)
if isinstance(task, (list, tuple)):
return tuple(str(item) for item in task)
return (str(task),)
def _extract_discrete_token_bins(
generated_ids: list[int],
start_token_id: int,
end_token_id: int,
token_id_to_bin: dict[int, int],
) -> list[int]:
start_idx = None
end_idx = None
for idx, token_id in enumerate(generated_ids):
if token_id == start_token_id:
start_idx = idx
break
if start_idx is not None:
for idx in range(start_idx + 1, len(generated_ids)):
if generated_ids[idx] == end_token_id:
end_idx = idx
break
span_start = 0 if start_idx is None else start_idx + 1
span_end = len(generated_ids) if end_idx is None else end_idx
return [
int(token_id_to_bin[token_id])
for token_id in generated_ids[span_start:span_end]
if token_id in token_id_to_bin
]
def _weighted_mean(values: Tensor, weights: Tensor | None) -> Tensor:
if weights is None:
return values.mean()
weights = weights.to(device=values.device, dtype=values.dtype)
return torch.dot(values, weights) / weights.sum().clamp_min(1.0)
def _weighted_per_example(
values: Tensor,
weights: Tensor | None,
example_indices: Tensor,
batch_size: int,
) -> Tensor:
values = values.float()
if weights is None:
weights = torch.ones_like(values)
else:
weights = weights.to(device=values.device, dtype=values.dtype)
loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
loss_sum.scatter_add_(0, example_indices, values * weights)
weight_sum.scatter_add_(0, example_indices, weights)
global_weight_sum = weight_sum.sum().clamp_min(1.0)
return loss_sum * float(batch_size) / global_weight_sum
class MolmoAct2Policy(PreTrainedPolicy):
"""MolmoAct2 policy wrapping the vendored HF model for LeRobot.
Supports three training modes via ``config.action_mode``:
``"continuous"`` (flow-matching only), ``"discrete"`` (autoregressive
token prediction only), or ``"both"`` (joint loss). At inference,
``config.inference_action_mode`` selects which head generates actions.
"""
config_class = MolmoAct2Config
name = "molmoact2"
@@ -149,10 +518,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
**kwargs,
):
super().__init__(config, *inputs, **kwargs)
self.config.apply_norm_tag_metadata()
_apply_norm_tag_metadata(self.config)
self.config.validate_features()
del inputs, kwargs, dataset_stats, dataset_meta
self._checkpoint_action_mode = self.config.saved_policy_action_mode()
self._checkpoint_action_mode = _saved_policy_action_mode(self.config)
self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps)
self._rollout_action_generator: torch.Generator | None = None
self._rollout_task_key: tuple[Any, ...] | None = None
@@ -160,7 +529,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
self.rtc_processor: RTCProcessor | None = None
self.action_tokenizer: Any | None = None
self._load_hf_model()
self.config.validate_inference_action_mode(self._checkpoint_action_mode)
_validate_inference_action_mode(self.config, self._checkpoint_action_mode)
if self.config.enable_lora_vlm:
self._apply_lora_adapters()
self.init_rtc_processor()
@@ -212,7 +581,8 @@ class MolmoAct2Policy(PreTrainedPolicy):
"`policy.checkpoint_force_download=true` after the updated files are pushed."
)
checkpoint_action_mode = str(self.model.config.action_mode)
self.config.validate_checkpoint_action_mode(
_validate_checkpoint_action_mode(
self.config,
checkpoint_action_mode,
has_action_expert=bool(getattr(self.model.config, "add_action_expert", False)),
)
@@ -226,6 +596,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
self.train(self.training)
def reset(self) -> None:
"""Clear the action queue and rollout generator between episodes."""
self._action_queue = deque(maxlen=self.config.n_action_steps)
self._rollout_action_generator = None
@@ -334,6 +705,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
param.requires_grad = False
def get_optim_params(self) -> list[dict[str, Any]]:
"""Return optimizer param groups with per-component learning rates."""
vit_params: list[Tensor] = []
connector_params: list[Tensor] = []
action_expert_params: list[Tensor] = []
@@ -419,33 +791,6 @@ class MolmoAct2Policy(PreTrainedPolicy):
return int(value)
raise RuntimeError("MolmoAct2 could not resolve an action generation horizon.")
@staticmethod
def _mask_discrete_action_spans(
*,
input_ids: Tensor,
mask: Tensor,
start_token_id: int | None,
end_token_id: int | None,
) -> Tensor:
if start_token_id is None or end_token_id is None:
return mask
mask = mask.clone()
for batch_idx in range(input_ids.shape[0]):
row = input_ids[batch_idx]
starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist()
ends = (row == int(end_token_id)).nonzero(as_tuple=False).flatten().tolist()
end_ptr = 0
for start in starts:
while end_ptr < len(ends) and ends[end_ptr] < start:
end_ptr += 1
if end_ptr >= len(ends):
mask[batch_idx, start:] = False
break
end = int(ends[end_ptr])
mask[batch_idx, start : end + 1] = False
end_ptr += 1
return mask
def _encoder_attention_mask_for_action_expert(
self,
*,
@@ -470,21 +815,13 @@ class MolmoAct2Policy(PreTrainedPolicy):
eos_token_id = getattr(self.model.config, "eos_token_id", None)
if eos_token_id is not None:
mask &= input_ids != int(eos_token_id)
return self._mask_discrete_action_spans(
return _mask_discrete_action_spans(
input_ids=input_ids,
mask=mask,
start_token_id=getattr(self.model.config, "action_start_token_id", None),
end_token_id=getattr(self.model.config, "action_end_token_id", None),
)
@staticmethod
def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]:
attention_mask = model_inputs.get("attention_mask")
if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()):
model_inputs = dict(model_inputs)
model_inputs.pop("attention_mask", None)
return model_inputs
def _load_discrete_action_tokenizer(self) -> Any:
if self.action_tokenizer is None:
require_package("transformers", extra="molmoact2")
@@ -498,27 +835,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
return self.action_tokenizer
def _resolve_inference_action_mode(self, requested_mode: str | None) -> str:
return self.config.resolve_inference_action_mode(requested_mode, self._checkpoint_action_mode)
@staticmethod
def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int:
seed = 0
for idx in range(batch_size):
seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1)
return seed
@staticmethod
def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None:
task = batch.get("task")
if task is None:
task = batch.get("observation.language")
if task is None:
return None
if isinstance(task, str):
return (task,)
if isinstance(task, (list, tuple)):
return tuple(str(item) for item in task)
return (str(task),)
return _resolve_inference_action_mode(self.config, requested_mode, self._checkpoint_action_mode)
def _rollout_generator_for_inputs(
self,
@@ -532,7 +849,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
if self._rollout_action_generator is not None:
return self._rollout_action_generator
task_signature = self._rollout_task_signature(batch)
task_signature = _rollout_task_signature(batch)
if task_signature != self._rollout_task_key:
self._rollout_task_key = task_signature
self._rollout_index_for_task = 0
@@ -545,72 +862,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
device if device.type == "cuda" and torch.cuda.is_available() else torch.device("cpu")
)
generator = torch.Generator(device=generator_device)
generator.manual_seed(self._combine_rollout_seeds(first_seed, batch_size))
generator.manual_seed(_combine_rollout_seeds(first_seed, batch_size))
self._rollout_action_generator = generator
return generator
@staticmethod
def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None:
if mask is None:
return None
return (
mask.unsqueeze(1)
.expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1)))
.reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:])
)
@staticmethod
def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None:
if action_dim_is_pad is None:
return None
mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool)
if mask.ndim == 1:
mask = mask.unsqueeze(0)
if mask.shape[-1] != target.shape[-1]:
raise ValueError(
f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}."
)
if mask.shape[0] == 1 and target.shape[0] != 1:
mask = mask.expand(target.shape[0], -1)
if mask.shape[0] != target.shape[0]:
raise ValueError(
f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}."
)
while mask.ndim < target.ndim:
mask = mask.unsqueeze(1)
return mask
@classmethod
def _mask_action_dim_tensor(cls, tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
if not cls._mask_enabled_static(action_dim_is_pad):
return tensor
valid_mask = cls._action_dim_valid_mask(tensor, action_dim_is_pad)
if valid_mask is None:
return tensor
return tensor.masked_fill(~valid_mask, 0)
@staticmethod
def _mask_enabled_static(action_dim_is_pad: Tensor | None) -> bool:
return action_dim_is_pad is not None
@classmethod
def _apply_action_dim_padding_mask(cls, loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
valid_mask = cls._action_dim_valid_mask(loss, action_dim_is_pad)
if valid_mask is None:
return loss
valid = valid_mask.to(dtype=loss.dtype)
denom = valid.sum(dim=-1).clamp_min(1.0)
return (loss * valid).sum(dim=-1) / denom
@staticmethod
def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor:
if action_horizon_is_pad is None:
return loss
valid_action = (
(~action_horizon_is_pad.to(device=loss.device, dtype=torch.bool)).unsqueeze(1).unsqueeze(-1)
)
return loss * valid_action
def _prepare_flow_matching_tensors(
self,
*,
@@ -649,7 +904,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
)
if self.config.mask_action_dim_padding:
actions = self._mask_action_dim_tensor(actions, action_dim_is_pad)
actions = _mask_action_dim_tensor(actions, action_dim_is_pad)
expected_noise_shape = (batch_size, num_flow_timesteps, actions.shape[1], actions.shape[2])
if noise is None:
@@ -661,7 +916,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
f"flow noise must have shape {expected_noise_shape}, got {tuple(noise.shape)}."
)
if self.config.mask_action_dim_padding:
noise = self._mask_action_dim_tensor(noise, action_dim_is_pad)
noise = _mask_action_dim_tensor(noise, action_dim_is_pad)
t_broadcast = timesteps.view(batch_size, num_flow_timesteps, 1, 1)
actions_expanded = actions.unsqueeze(1).expand(-1, num_flow_timesteps, -1, -1)
@@ -789,7 +1044,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
valid_action = None
if action_attention_mask is not None:
valid_action = action_attention_mask.to(device=device, dtype=actions.dtype).unsqueeze(-1)
valid_action = self._expand_mask(valid_action, num_flow_timesteps)
valid_action = _expand_mask(valid_action, num_flow_timesteps)
rope_cache = None
if len(action_expert.blocks) > 0 and action_expert.blocks[0].self_attn.rope is not None:
@@ -804,14 +1059,14 @@ class MolmoAct2Policy(PreTrainedPolicy):
batch_size,
actions.dtype,
)
cross_mask = self._expand_mask(cross_mask, num_flow_timesteps)
cross_mask = _expand_mask(cross_mask, num_flow_timesteps)
self_mask = action_expert._build_self_attention_mask(
action_attention_mask,
actions.shape[1],
device,
actions.dtype,
)
self_mask = self._expand_mask(self_mask, num_flow_timesteps)
self_mask = _expand_mask(self_mask, num_flow_timesteps)
conditioning = self._action_time_conditioning(action_expert, timesteps_flat)
action_hidden = action_expert.action_embed(xt_flat)
@@ -871,8 +1126,8 @@ class MolmoAct2Policy(PreTrainedPolicy):
if k_norm is not None:
k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2)
if num_flow_timesteps != 1:
k_ctx = self._expand_mask(k_ctx, num_flow_timesteps)
v_ctx = self._expand_mask(v_ctx, num_flow_timesteps)
k_ctx = _expand_mask(k_ctx, num_flow_timesteps)
v_ctx = _expand_mask(v_ctx, num_flow_timesteps)
next_action_hidden = action_block(
layer_action_hidden,
@@ -912,9 +1167,9 @@ class MolmoAct2Policy(PreTrainedPolicy):
)
loss = F.mse_loss(pred_velocity, target_velocity, reduction="none")
loss = self._apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad"))
loss = _apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad"))
if self.config.mask_action_dim_padding:
loss = self._apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad"))
loss = _apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad"))
loss = loss.reshape(batch_size, -1).mean(dim=1)
if reduction == "mean":
loss = loss.mean()
@@ -933,32 +1188,6 @@ class MolmoAct2Policy(PreTrainedPolicy):
example_weights[nonempty] = 2.0 / torch.sqrt(token_counts[nonempty])
return example_weights[:, None].expand_as(valid_positions)[valid_positions].to(dtype=torch.float32)
@staticmethod
def _weighted_mean(values: Tensor, weights: Tensor | None) -> Tensor:
if weights is None:
return values.mean()
weights = weights.to(device=values.device, dtype=values.dtype)
return torch.dot(values, weights) / weights.sum().clamp_min(1.0)
@staticmethod
def _weighted_per_example(
values: Tensor,
weights: Tensor | None,
example_indices: Tensor,
batch_size: int,
) -> Tensor:
values = values.float()
if weights is None:
weights = torch.ones_like(values)
else:
weights = weights.to(device=values.device, dtype=values.dtype)
loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
loss_sum.scatter_add_(0, example_indices, values * weights)
weight_sum.scatter_add_(0, example_indices, weights)
global_weight_sum = weight_sum.sum().clamp_min(1.0)
return loss_sum * float(batch_size) / global_weight_sum
def _discrete_loss_from_backbone_outputs(
self,
batch: dict[str, Tensor],
@@ -992,56 +1221,28 @@ class MolmoAct2Policy(PreTrainedPolicy):
token_weights = self._discrete_token_weights(valid_positions)
if reduction == "none":
example_indices = valid_positions.nonzero(as_tuple=False)[:, 0].to(device=hidden_states.device)
ce_loss = self._weighted_per_example(
ce_loss = _weighted_per_example(
token_ce_loss,
token_weights,
example_indices,
int(labels.shape[0]),
)
else:
ce_loss = self._weighted_mean(token_ce_loss, token_weights)
ce_loss = _weighted_mean(token_ce_loss, token_weights)
if not self.config.softmax_auxiliary_loss:
return ce_loss, None
if reduction == "none":
z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_per_example(
z_loss = self.config.softmax_auxiliary_loss_scale * _weighted_per_example(
log_z.pow(2),
token_weights,
example_indices,
int(labels.shape[0]),
)
else:
z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_mean(
log_z.pow(2), token_weights
)
z_loss = self.config.softmax_auxiliary_loss_scale * _weighted_mean(log_z.pow(2), token_weights)
return ce_loss, z_loss
@staticmethod
def _extract_discrete_token_bins(
generated_ids: list[int],
start_token_id: int,
end_token_id: int,
token_id_to_bin: dict[int, int],
) -> list[int]:
start_idx = None
end_idx = None
for idx, token_id in enumerate(generated_ids):
if token_id == start_token_id:
start_idx = idx
break
if start_idx is not None:
for idx in range(start_idx + 1, len(generated_ids)):
if generated_ids[idx] == end_token_id:
end_idx = idx
break
span_start = 0 if start_idx is None else start_idx + 1
span_end = len(generated_ids) if end_idx is None else end_idx
return [
int(token_id_to_bin[token_id])
for token_id in generated_ids[span_start:span_end]
if token_id in token_id_to_bin
]
def _action_token_id_to_bin(self) -> dict[int, int]:
method = getattr(self.model, "_action_token_id_to_bin", None)
if callable(method):
@@ -1179,7 +1380,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
chunks: list[Tensor] = []
for token_row in generated_token_ids:
generated_ids = [int(token_id) for token_id in token_row.detach().cpu().tolist()]
discrete_token_ids = self._extract_discrete_token_bins(
discrete_token_ids = _extract_discrete_token_bins(
generated_ids,
int(self.model.config.action_start_token_id),
int(self.model.config.action_end_token_id),
@@ -1218,7 +1419,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
model_inputs: dict[str, Tensor],
action_dim: int,
) -> Tensor:
model_inputs = self._drop_trivial_attention_mask(model_inputs)
model_inputs = _drop_trivial_attention_mask(model_inputs)
max_steps = self._discrete_generation_max_steps()
static_cache, attention_bias = self._make_discrete_ar_graph_decode_inputs(
model_inputs,
@@ -1294,7 +1495,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
generator=generator,
)
if self.config.mask_action_dim_padding:
trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad)
trajectory = _mask_action_dim_tensor(trajectory, action_dim_is_pad)
action_context = action_expert.prepare_context(
encoder_kv_states=encoder_kv_states,
@@ -1327,7 +1528,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
modulation=step_modulation,
)
if mask_enabled:
velocity = self._mask_action_dim_tensor(velocity, action_dim_is_pad)
velocity = _mask_action_dim_tensor(velocity, action_dim_is_pad)
return velocity
if self._rtc_enabled():
@@ -1352,7 +1553,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
trajectory = trajectory + dt * velocity
if mask_enabled:
trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad)
trajectory = _mask_action_dim_tensor(trajectory, action_dim_is_pad)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=float(flow_timestep[0].item()), x_t=trajectory, v_t=velocity)
@@ -1363,6 +1564,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
batch: dict[str, Tensor],
reduction: str = "mean",
) -> tuple[Tensor, dict[str, Any]]:
"""Compute training loss (flow-matching and/or discrete token loss)."""
if reduction not in {"mean", "none"}:
raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.")
model_inputs = self._model_inputs(batch)
@@ -1422,6 +1624,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Generate an action chunk via continuous flow matching or discrete AR decoding."""
if "action_mode" in kwargs:
raise TypeError(
"MolmoAct2 predict_action_chunk got unexpected keyword argument 'action_mode'; "
@@ -1476,6 +1679,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Pop one action step from the queue, regenerating the chunk when empty."""
if self._rtc_enabled():
raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk")
self.eval()
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,23 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
import logging
import os
from pathlib import Path
from typing import ClassVar
import numpy as np
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
from ..modeling_molmoact2 import _hf_token
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
logger = logging.getLogger(__name__)
def _resolve_tokenizer_location(
@@ -42,6 +36,8 @@ def _resolve_tokenizer_location(
local_path = Path(str(tokenizer_path)).expanduser()
if local_path.exists():
return str(local_path)
from huggingface_hub import snapshot_download
return snapshot_download(
repo_id=str(tokenizer_path),
repo_type="model",
@@ -134,9 +130,8 @@ class UniversalActionProcessor(ProcessorMixin):
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
except Exception:
logger.warning("Error decoding tokens: %s", token, exc_info=True)
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
MolmoAct2 configuration
"""
from typing import Optional, Any
from typing import Any
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,33 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Image processor class for MolmoAct2"""
from typing import Optional, Union
import numpy as np
import einops
import numpy as np
import torch
import torchvision.transforms
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.image_transforms import convert_to_rgb
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_flat_list_of_images,
valid_images,
to_numpy_array,
valid_images,
)
from transformers.image_transforms import convert_to_rgb
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
@@ -73,8 +66,8 @@ def resize_image(
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
assert image.dtype == torch.uint8, (
f"SigLIP expects float images or uint8 images, but got {image.dtype}"
)
in_min = 0.0
in_max = 255.0
@@ -96,7 +89,6 @@ def resize_image(
def select_tiling(h, w, patch_size, max_num_crops):
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
original_size = np.stack([h, w]) # [1, 2]
original_res = h * w
tilings = []
for i in range(1, max_num_crops + 1):
for j in range(1, max_num_crops + 1):
@@ -406,13 +398,17 @@ class MolmoAct2ImageProcessor(BaseImageProcessor):
image_std: float | list[float] | None = None,
do_convert_rgb: bool = True,
max_crops: int = 8,
overlap_margins: list[int] = [4, 4],
overlap_margins: list[int] | None = None,
crop_mode: str = "overlap-and-resize-c2",
patch_size: int = 14,
pooling_size: list[int] = [2, 2],
pooling_size: list[int] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
if overlap_margins is None:
overlap_margins = [4, 4]
if pooling_size is None:
pooling_size = [2, 2]
size = size if size is not None else {"height": 378, "width": 378}
size = get_size_dict(size, default_to_square=True)
self.size = size
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,16 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Inference utilities for MolmoAct2"""
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any
import torch
from torch.nn import functional as F
from torch.nn import functional as F # noqa: N812
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
@@ -679,7 +676,7 @@ def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
def _copy_context_(dst: Any, src: Any) -> None:
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts, strict=False):
dst_k.copy_(src_k)
dst_v.copy_(src_v)
if src.cross_mask is not None:
@@ -689,7 +686,7 @@ def _copy_context_(dst: Any, src: Any) -> None:
if src.valid_action is not None:
dst.valid_action.copy_(src.valid_action)
if src.rope_cache is not None:
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache, strict=False):
dst_tensor.copy_(src_tensor)
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,24 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Modeling code for MolmoAct2"""
# ruff: noqa: N806
import json
import math
import os
import re
from collections.abc import Callable, Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Optional
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.nn import functional as F # noqa: N812
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
@@ -647,7 +646,7 @@ class ActionExpert(nn.Module):
f"got {len(encoder_kv_states)}."
)
kv_contexts = []
for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states):
for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states, strict=False):
k_ctx = self._project_kv_tensor(k_in, self.context_k_proj)
v_ctx = self._project_kv_tensor(v_in, self.context_v_proj)
k_norm = block.cross_attn.k_norm
@@ -732,7 +731,7 @@ class ActionExpert(nn.Module):
timesteps: Sequence[torch.Tensor],
) -> Sequence[ActionExpertStepModulation]:
cache = []
for idx, step_t in enumerate(timesteps):
for _idx, step_t in enumerate(timesteps):
conditioning = self._time_conditioning(step_t)
block_modulations = []
for block in self.blocks:
@@ -786,8 +785,8 @@ class ActionExpert(nn.Module):
x = self.action_embed(actions)
if context.valid_action is not None:
x = x * context.valid_action
for idx, (block, kv_context, block_modulation) in enumerate(
zip(self.blocks, context.kv_contexts, block_modulations)
for _idx, (block, kv_context, block_modulation) in enumerate(
zip(self.blocks, context.kv_contexts, block_modulations, strict=False)
):
x = block(
x,
@@ -2874,7 +2873,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel):
depth_mask=depth_mask,
encoder_attention_mask=encoder_attention_mask,
)
for gate, source in zip(gate_head, sources)
for gate, source in zip(gate_head, sources, strict=False)
]
return gates, depth_mask
gate = self._depth_gate_from_source(
@@ -4458,7 +4457,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
```python
>>> from PIL import Image
>>> import requests
>>> from lerobot.policies.molmoact2.hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
>>> from lerobot.policies.molmoact2.molmoact2_hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
>>> from lerobot.policies.molmoact2.processor_molmoact2 import _load_local_molmoact2_processor
>>> model = MolmoAct2ForConditionalGeneration.from_pretrained("...")
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,45 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
Processor class for MolmoAct2.
"""
from typing import Optional, Union
import dataclasses
import numpy as np
from transformers import AutoTokenizer
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
Unpack,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.video_utils import VideoInput
from transformers import AutoTokenizer
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
from .image_processing_molmoact2 import MolmoAct2ImageProcessor, MolmoAct2ImagesKwargs
from .video_processing_molmoact2 import MolmoAct2VideoProcessor, MolmoAct2VideoProcessorKwargs
logger = logging.get_logger(__name__)
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
IM_START_TOKEN = f"<im_start>"
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
FRAME_START_TOKEN = f"<frame_start>"
IM_END_TOKEN = f"<im_end>"
FRAME_END_TOKEN = f"<frame_end>"
IM_COL_TOKEN = f"<im_col>"
IMAGE_PATCH_TOKEN = "<im_patch>" # nosec B105 # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = "<im_low>" # nosec B105 # Where to insert low-res tokens
IM_START_TOKEN = "<im_start>" # nosec B105
LOW_RES_IMAGE_START_TOKEN = "<low_res_im_start>" # nosec B105
FRAME_START_TOKEN = "<frame_start>" # nosec B105
IM_END_TOKEN = "<im_end>" # nosec B105
FRAME_END_TOKEN = "<frame_end>" # nosec B105
IM_COL_TOKEN = "<im_col>" # nosec B105
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"
@@ -224,7 +216,7 @@ class MolmoAct2Processor(ProcessorMixin):
input_ids = input_ids[None, :]
attention_mask = attention_mask[None, :]
B, S = input_ids.shape
B, S = input_ids.shape # noqa: N806
# Handle zero-length sequence
if S == 0:
@@ -364,7 +356,7 @@ class MolmoAct2Processor(ProcessorMixin):
assert num_videos in {0, 1}, "At most one video is supported for now"
video_grids_i = video_grids[index : index + num_videos]
metadata_i = video_metadata[index : index + num_videos]
for video_grid, metadata in zip(video_grids_i, metadata_i):
for video_grid, metadata in zip(video_grids_i, metadata_i, strict=False):
video_string = self.get_video_string(
video_grid,
metadata.timestamps,
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,25 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Video processor class for MolmoAct2"""
from functools import partial
import os
import warnings
from collections.abc import Callable
from contextlib import redirect_stdout
from functools import partial
from io import BytesIO
from urllib.parse import urlparse
from typing import Optional, Union
from collections.abc import Callable
import einops
import numpy as np
import requests
import einops
import torch
import torchvision.transforms
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
@@ -41,27 +37,24 @@ from transformers.image_utils import (
SizeDict,
validate_kwargs,
)
from transformers.video_utils import (
VideoInput,
is_valid_video,
make_batched_videos,
make_batched_metadata,
VideoMetadata,
)
from transformers.processing_utils import Unpack, VideosKwargs
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import (
TensorType,
is_av_available,
is_decord_available,
is_torchcodec_available,
is_yt_dlp_available,
TensorType,
logging,
to_numpy,
)
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.video_utils import (
VideoInput,
VideoMetadata,
is_valid_video,
make_batched_metadata,
make_batched_videos,
)
logger = logging.get_logger(__name__)
@@ -102,8 +95,8 @@ def resize_image(
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
assert image.dtype == torch.uint8, (
f"SigLIP expects float images or uint8 images, but got {image.dtype}"
)
in_min = 0.0
in_max = 255.0
@@ -548,9 +541,8 @@ def get_target_fps(
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if "uniform" in frame_sample_mode:
if num_frames_sampled_at_fps > max_frames:
break
if "uniform" in frame_sample_mode and num_frames_sampled_at_fps > max_frames:
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
@@ -779,13 +771,15 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
elif is_torchcodec_available():
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `torchcodec`."
"Falling back to `torchcodec`.",
stacklevel=2,
)
backend = "torchcodec"
else:
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `PyAV`."
"Falling back to `PyAV`.",
stacklevel=2,
)
backend = "pyav"
@@ -795,7 +789,8 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
*[
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
for x in video_url_or_urls
]
],
strict=False,
)
)
else:
@@ -821,7 +816,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
sampled_videos = []
sampled_metadata = []
for video, metadata in zip(videos, video_metadata):
for video, metadata in zip(videos, video_metadata, strict=False):
indices = sample_indices_fn(metadata=metadata)
metadata.frames_indices = indices
sampled_videos.append(video[indices])
@@ -985,11 +980,11 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
pixel_values_videos = np.concatenate(batch_crops, 0)
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
data = dict(
pixel_values_videos=pixel_values_videos,
video_token_pooling=video_token_pooling,
video_grids=video_grids,
)
data = {
"pixel_values_videos": pixel_values_videos,
"video_token_pooling": video_token_pooling,
"video_grids": video_grids,
}
return BatchFeature(data, tensor_type=return_tensors)
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,10 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""MolmoAct2 pre/post processing pipeline.
Builds the multimodal prompt (images, discretised state, task text),
tokenises it via the vendored MolmoAct2 processor, and handles quantile
normalisation with optional per-dimension gripper masking.
"""
from __future__ import annotations
import json
import os
import logging
import math
import re
from contextlib import suppress
from copy import deepcopy
@@ -27,7 +33,6 @@ from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from huggingface_hub import snapshot_download
from torch import Tensor
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
@@ -54,14 +59,71 @@ from lerobot.utils.constants import (
)
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
from .configuration_molmoact2 import MolmoAct2Config, infer_molmoact2_max_sequence_length
from .configuration_molmoact2 import MolmoAct2Config
from .modeling_molmoact2 import _hf_token, _resolve_checkpoint_location
logger = logging.getLogger(__name__)
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
MOLMOACT2_TASK_TOKEN_BUDGET = 32
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
def _round_up(value: int, multiple: int) -> int:
return int(math.ceil(value / multiple) * multiple)
def infer_molmoact2_max_sequence_length(
*,
num_images: int,
state_dim: int,
action_dim: int,
action_horizon: int,
include_discrete_action: bool,
) -> int:
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
if num_images < 1:
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim < 0:
state_dim = 0
if action_dim < 1:
action_dim = 1
if action_horizon < 1:
action_horizon = 1
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
prompt_tokens = (
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
+ MOLMOACT2_TASK_TOKEN_BUDGET
+ state_dim
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
)
action_tokens = 0
if include_discrete_action:
action_tokens_per_step = max(
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
)
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
return _round_up(
image_tokens + prompt_tokens + action_tokens,
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
)
if TYPE_CHECKING or _transformers_available:
from transformers import Qwen2Tokenizer
from .hf_model.image_processing_molmoact2 import MolmoAct2ImageProcessor
from .hf_model.processing_molmoact2 import MolmoAct2Processor
from .hf_model.video_processing_molmoact2 import MolmoAct2VideoProcessor
from .molmoact2_hf_model.image_processing_molmoact2 import MolmoAct2ImageProcessor
from .molmoact2_hf_model.processing_molmoact2 import MolmoAct2Processor
from .molmoact2_hf_model.video_processing_molmoact2 import MolmoAct2VideoProcessor
else:
Qwen2Tokenizer = None
MolmoAct2ImageProcessor = None
@@ -69,7 +131,7 @@ else:
MolmoAct2VideoProcessor = None
if TYPE_CHECKING or (_transformers_available and _scipy_available):
from .hf_model.action_tokenizer import UniversalActionProcessor
from .molmoact2_hf_model.action_tokenizer import UniversalActionProcessor
else:
UniversalActionProcessor = None
@@ -97,32 +159,6 @@ _QUESTION_PREFIX_PATTERNS = tuple(
)
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _load_hf_norm_stats_for_tag(
checkpoint_path: str,
*,
+35 -48
View File
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -35,16 +33,16 @@ pytest.importorskip("scipy")
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies import get_policy_class, make_policy_config
from lerobot.policies.molmoact2 import (
configuration_molmoact2 as molmoact2_config,
modeling_molmoact2 as molmoact2_modeling,
processor_molmoact2 as molmoact2_processor,
)
from lerobot.policies.molmoact2.configuration_molmoact2 import (
MolmoAct2Config,
MolmoAct2CosineDecayWithWarmupSchedulerConfig,
infer_molmoact2_max_sequence_length,
from lerobot.policies.molmoact2.configuration_molmoact2 import MolmoAct2Config
from lerobot.policies.molmoact2.modeling_molmoact2 import (
MolmoAct2Policy,
_apply_action_chunk_padding_mask,
_apply_action_dim_padding_mask,
_combine_rollout_seeds,
)
from lerobot.policies.molmoact2.modeling_molmoact2 import MolmoAct2Policy
from lerobot.policies.molmoact2.processor_molmoact2 import (
MolmoAct2ClampNormalizedProcessorStep,
MolmoAct2MaskedNormalizerProcessorStep,
@@ -53,6 +51,7 @@ from lerobot.policies.molmoact2.processor_molmoact2 import (
_add_gripper_masks_to_stats,
_build_discrete_state_string,
_normalize_question_text,
infer_molmoact2_max_sequence_length,
make_molmoact2_pre_post_processors,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig
@@ -71,34 +70,38 @@ def test_molmoact2_policy_registration():
assert cfg.per_episode_seed is False
assert cfg.eval_seed is None
assert cfg.normalize_language is True
assert cfg.get_scheduler_preset().num_decay_steps is None
assert cfg.get_scheduler_preset().num_decay_steps == 100_000
assert cfg.action_delta_indices == list(range(cfg.chunk_size))
assert get_policy_class("molmoact2") is MolmoAct2Policy
def test_molmoact2_checkpoint_download_ignores_remote_python(monkeypatch):
import huggingface_hub
download_kwargs = {}
def fake_snapshot_download(**kwargs):
download_kwargs.update(kwargs)
return "/tmp/downloaded-molmoact2"
monkeypatch.setattr(molmoact2_config, "snapshot_download", fake_snapshot_download)
monkeypatch.setattr(huggingface_hub, "snapshot_download", fake_snapshot_download)
checkpoint_location = molmoact2_config._resolve_checkpoint_location("allenai/MolmoAct2")
checkpoint_location = molmoact2_modeling._resolve_checkpoint_location("allenai/MolmoAct2")
assert checkpoint_location == "/tmp/downloaded-molmoact2"
assert download_kwargs["ignore_patterns"] == ["*.py", "*.pyc", "__pycache__/*"]
def test_molmoact2_scheduler_decay_steps_auto_match_training_steps():
def test_molmoact2_scheduler_auto_scales_to_training_steps():
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
param = torch.nn.Parameter(torch.ones(()))
optimizer = torch.optim.AdamW([param], lr=0.001)
config = MolmoAct2CosineDecayWithWarmupSchedulerConfig(
config = CosineDecayWithWarmupSchedulerConfig(
peak_lr=0.01,
decay_lr=0.001,
num_warmup_steps=10,
num_decay_steps=None,
num_decay_steps=100_000,
)
scheduler = config.build(optimizer, num_training_steps=100)
@@ -123,9 +126,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task():
batch_size=3,
device=torch.device("cpu"),
)
expected_first = torch.Generator().manual_seed(
MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3)
)
expected_first = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1000, batch_size=3))
assert torch.allclose(torch.rand(4, generator=first), torch.rand(4, generator=expected_first))
policy.reset()
@@ -134,9 +135,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task():
batch_size=3,
device=torch.device("cpu"),
)
expected_second = torch.Generator().manual_seed(
MolmoAct2Policy._combine_rollout_seeds(first_seed=1003, batch_size=3)
)
expected_second = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1003, batch_size=3))
assert torch.allclose(torch.rand(4, generator=second), torch.rand(4, generator=expected_second))
policy.reset()
@@ -145,9 +144,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task():
batch_size=3,
device=torch.device("cpu"),
)
expected_new_task = torch.Generator().manual_seed(
MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3)
)
expected_new_task = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1000, batch_size=3))
assert torch.allclose(torch.rand(4, generator=new_task), torch.rand(4, generator=expected_new_task))
@@ -537,36 +534,26 @@ def test_train_action_expert_only_requires_continuous_action_mode():
def test_molmoact2_sequence_length_is_inferred_from_fixed_token_budget():
cfg = MolmoAct2Config(
action_mode="both",
chunk_size=10,
n_action_steps=10,
image_keys=["observation.images.image", "observation.images.wrist_image"],
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,))},
)
assert cfg.max_sequence_length is None
assert cfg.inferred_max_sequence_length() == 640
assert cfg.inferred_max_sequence_length(include_discrete_action=False) == 576
assert (
infer_molmoact2_max_sequence_length(
num_images=2,
state_dim=8,
action_dim=7,
action_horizon=30,
include_discrete_action=True,
num_images=2, state_dim=8, action_dim=7, action_horizon=10, include_discrete_action=True
)
== 640
)
assert (
infer_molmoact2_max_sequence_length(
num_images=2, state_dim=8, action_dim=7, action_horizon=10, include_discrete_action=False
)
== 576
)
assert (
infer_molmoact2_max_sequence_length(
num_images=2, state_dim=8, action_dim=7, action_horizon=30, include_discrete_action=True
)
== 768
)
def test_molmoact2_sequence_length_override_is_preserved():
cfg = MolmoAct2Config(max_sequence_length=1024)
assert cfg.inferred_max_sequence_length(num_images=2, state_dim=8, action_dim=7) == 1024
def test_train_action_expert_only_freezes_non_action_expert_params():
class DummyBackbone(torch.nn.Module):
def __init__(self):
@@ -963,7 +950,7 @@ def test_action_dim_padding_loss_reduces_like_old_trainer():
]
)
reduced = MolmoAct2Policy._apply_action_dim_padding_mask(loss, action_dim_is_pad)
reduced = _apply_action_dim_padding_mask(loss, action_dim_is_pad)
expected = torch.stack(
[
@@ -979,7 +966,7 @@ def test_action_chunk_padding_keeps_old_mean_denominator():
loss = torch.ones(1, 2, 4, 3)
action_horizon_is_pad = torch.tensor([[False, False, True, True]])
masked = MolmoAct2Policy._apply_action_chunk_padding_mask(loss, action_horizon_is_pad)
masked = _apply_action_chunk_padding_mask(loss, action_horizon_is_pad)
assert masked.mean().item() == 0.5