diff --git a/docs/source/molmoact2.mdx b/docs/source/molmoact2.mdx index c6ae24e9e..7927c6694 100644 --- a/docs/source/molmoact2.mdx +++ b/docs/source/molmoact2.mdx @@ -17,7 +17,7 @@ the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2). Install LeRobot with the MolmoAct2 optional dependencies: ```bash -pip install -e ".[molmoact2]" +uv sync --locked --extra molmoact2 ``` To run the models in this repository, you need an NVIDIA GPU. The measurements @@ -46,8 +46,8 @@ The repo has been tested with Ubuntu 22.04. To use MolmoAct2 in a LeRobot training config, set: -```python -policy.type=molmoact2 +```bash +--policy.type=molmoact2 ``` ## Training diff --git a/src/lerobot/policies/molmoact2/README.md b/src/lerobot/policies/molmoact2/README.md index ef419516d..9756785d9 120000 --- a/src/lerobot/policies/molmoact2/README.md +++ b/src/lerobot/policies/molmoact2/README.md @@ -1 +1 @@ -../../../../docs/source/policy_molmoact2_README.md \ No newline at end of file +../../../../docs/source/molmoact2.mdx \ No newline at end of file diff --git a/src/lerobot/policies/molmoact2/__init__.py b/src/lerobot/policies/molmoact2/__init__.py index bfef53bb2..a4e7695c2 100644 --- a/src/lerobot/policies/molmoact2/__init__.py +++ b/src/lerobot/policies/molmoact2/__init__.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/src/lerobot/policies/molmoact2/configuration_molmoact2.py b/src/lerobot/policies/molmoact2/configuration_molmoact2.py index de2585281..53aefdee6 100644 --- a/src/lerobot/policies/molmoact2/configuration_molmoact2.py +++ b/src/lerobot/policies/molmoact2/configuration_molmoact2.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,16 +14,9 @@ from __future__ import annotations -import json -import math -import os -from contextlib import suppress from dataclasses import dataclass, field -from pathlib import Path from typing import Any -from huggingface_hub import snapshot_download - from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig from lerobot.optim import ( AdamWConfig, @@ -37,146 +28,6 @@ from lerobot.utils.constants import ACTION, OBS_STATE from ..rtc.configuration_rtc import RTCConfig -MOLMOACT2_DEFAULT_NUM_IMAGES = 2 -MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196 -MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80 -MOLMOACT2_TASK_TOKEN_BUDGET = 32 -MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32 -MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64 -MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4 -MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6 -MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95 - - -def _hf_token() -> str | None: - return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") - - -def _resolve_checkpoint_location( - checkpoint_path: str, - *, - revision: str | None = None, - force_download: bool = False, -) -> str: - checkpoint_path = str(checkpoint_path or "").strip() - if not checkpoint_path: - raise ValueError("MolmoAct2 policy requires `checkpoint_path`.") - local_path = Path(checkpoint_path).expanduser() - if local_path.exists(): - return str(local_path) - return snapshot_download( - repo_id=checkpoint_path, - repo_type="model", - revision=revision, - force_download=force_download, - ignore_patterns=["*.py", "*.pyc", "__pycache__/*"], - token=_hf_token(), - ) - - -def _load_hf_norm_metadata_for_tag( - checkpoint_path: str, - *, - revision: str | None, - force_download: bool, - norm_tag: str | None, -) -> dict[str, Any]: - norm_tag = str(norm_tag or "").strip() - if not norm_tag: - return {} - checkpoint_location = Path( - _resolve_checkpoint_location( - checkpoint_path, - revision=revision, - force_download=force_download, - ) - ) - norm_stats_filename = "norm_stats.json" - config_path = checkpoint_location / "config.json" - if config_path.exists(): - with suppress(OSError, json.JSONDecodeError): - norm_stats_filename = str( - json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename - ) - stats_path = checkpoint_location / norm_stats_filename - if not stats_path.exists(): - raise FileNotFoundError( - f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}." - ) - payload = json.loads(stats_path.read_text()) - metadata_by_tag = payload.get("metadata_by_tag") - if not isinstance(metadata_by_tag, dict): - raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.") - metadata = metadata_by_tag.get(norm_tag) - if not isinstance(metadata, dict): - available = sorted(str(tag) for tag in metadata_by_tag) - raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.") - return metadata - - -@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup") -@dataclass -class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig): - """MolmoAct2-local cosine scheduler with optional decay-step auto-match. - - LeRobot's generic cosine scheduler keeps an explicit integer decay length. - For MolmoAct2, leaving num_decay_steps unset means "decay across this run's - training steps"; build() is the first point where num_training_steps is known. - """ - - num_decay_steps: int | None - - def build(self, optimizer, num_training_steps: int): - return CosineDecayWithWarmupSchedulerConfig( - peak_lr=self.peak_lr, - decay_lr=self.decay_lr, - num_warmup_steps=self.num_warmup_steps, - num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps, - ).build(optimizer, num_training_steps=num_training_steps) - - -def _round_up(value: int, multiple: int) -> int: - return int(math.ceil(value / multiple) * multiple) - - -def infer_molmoact2_max_sequence_length( - *, - num_images: int, - state_dim: int, - action_dim: int, - action_horizon: int, - include_discrete_action: bool, -) -> int: - """Infer the padded text/image sequence cap from MolmoAct2's fixed token layout.""" - if num_images < 1: - num_images = MOLMOACT2_DEFAULT_NUM_IMAGES - if state_dim < 0: - state_dim = 0 - if action_dim < 1: - action_dim = 1 - if action_horizon < 1: - action_horizon = 1 - - image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE - prompt_tokens = ( - MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET - + MOLMOACT2_TASK_TOKEN_BUDGET - + state_dim - + MOLMOACT2_SEQUENCE_LENGTH_MARGIN - ) - action_tokens = 0 - if include_discrete_action: - action_tokens_per_step = max( - MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP, - math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM), - ) - action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step - - return _round_up( - image_tokens + prompt_tokens + action_tokens, - MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE, - ) - @PreTrainedConfig.register_subclass("molmoact2") @dataclass @@ -255,7 +106,7 @@ class MolmoAct2Config(PreTrainedConfig): optimizer_grad_clip_norm: float = 1.0 scheduler_warmup_steps: int = 200 - scheduler_decay_steps: int | None = None + scheduler_decay_steps: int = 100_000 scheduler_decay_lr: float = 1e-6 normalization_mapping: dict[str, NormalizationMode] = field( @@ -333,41 +184,6 @@ class MolmoAct2Config(PreTrainedConfig): if self.max_sequence_length is not None and self.max_sequence_length < 1: raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.") - def inferred_max_sequence_length( - self, - *, - num_images: int | None = None, - state_dim: int | None = None, - action_dim: int | None = None, - action_horizon: int | None = None, - include_discrete_action: bool | None = None, - ) -> int: - if self.max_sequence_length is not None: - return int(self.max_sequence_length) - - if num_images is None: - num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES - if state_dim is None: - state_feature = self.robot_state_feature - state_dim = int(state_feature.shape[0]) if state_feature is not None else 0 - if action_dim is None: - action_feature = self.action_feature - action_dim = ( - int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim - ) - if action_horizon is None: - action_horizon = self.chunk_size - if include_discrete_action is None: - include_discrete_action = self.action_mode in {"discrete", "both"} - - return infer_molmoact2_max_sequence_length( - num_images=int(num_images), - state_dim=int(state_dim), - action_dim=int(action_dim), - action_horizon=int(action_horizon), - include_discrete_action=bool(include_discrete_action), - ) - @property def observation_delta_indices(self) -> None: return None @@ -390,7 +206,7 @@ class MolmoAct2Config(PreTrainedConfig): ) def get_scheduler_preset(self) -> LRSchedulerConfig | None: - return MolmoAct2CosineDecayWithWarmupSchedulerConfig( + return CosineDecayWithWarmupSchedulerConfig( peak_lr=self.optimizer_lr, decay_lr=self.scheduler_decay_lr, num_warmup_steps=self.scheduler_warmup_steps, @@ -426,94 +242,3 @@ class MolmoAct2Config(PreTrainedConfig): shape=(self.expected_max_action_dim,), ) self.output_features[ACTION] = action_feature - - def apply_norm_tag_metadata(self) -> None: - if not str(self.norm_tag or "").strip(): - return - metadata = _load_hf_norm_metadata_for_tag( - self.checkpoint_path, - revision=self.checkpoint_revision, - force_download=bool(self.checkpoint_force_download), - norm_tag=self.norm_tag, - ) - if metadata.get("action_horizon") is not None: - self.chunk_size = int(metadata["action_horizon"]) - if metadata.get("n_action_steps") is not None: - self.n_action_steps = int(metadata["n_action_steps"]) - if not self.setup_type and metadata.get("setup_type") is not None: - self.setup_type = str(metadata["setup_type"]) - if not self.control_mode and metadata.get("control_mode") is not None: - self.control_mode = str(metadata["control_mode"]) - - def saved_policy_action_mode(self) -> str | None: - pretrained_path = getattr(self, "pretrained_path", None) - if pretrained_path is None: - return None - config_path = Path(pretrained_path) / "config.json" - if not config_path.exists(): - return None - try: - mode = json.loads(config_path.read_text()).get("action_mode") - except (OSError, json.JSONDecodeError): - return None - if mode in {"continuous", "discrete", "both"}: - return str(mode) - return None - - def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str: - return saved_policy_action_mode or self.action_mode - - def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None: - requested_mode = self.inference_action_mode - if requested_mode is None: - return - training_mode = self.training_action_mode(saved_policy_action_mode) - if requested_mode == "continuous" and training_mode == "discrete": - raise ValueError( - "MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run " - "continuous inference." - ) - if requested_mode == "discrete" and training_mode == "continuous": - raise ValueError( - "MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run " - "discrete inference. Train with action_mode='both' or action_mode='discrete' first." - ) - - def validate_checkpoint_action_mode( - self, - checkpoint_action_mode: str, - *, - has_action_expert: bool, - ) -> None: - if self.action_mode == "both" and checkpoint_action_mode != "both": - raise ValueError( - f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}." - ) - if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}: - raise ValueError( - f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, " - f"got {checkpoint_action_mode!r}." - ) - if self.action_mode in {"continuous", "both"} and not has_action_expert: - raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.") - - def resolve_inference_action_mode( - self, - requested_mode: str | None, - saved_policy_action_mode: str | None = None, - ) -> str: - training_mode = self.training_action_mode(saved_policy_action_mode) - if requested_mode is None: - requested_mode = self.inference_action_mode - if requested_mode is None: - raise ValueError( - "MolmoAct2 inference requires `inference_action_mode` to be set explicitly " - "to either 'continuous' or 'discrete'." - ) - if requested_mode not in {"continuous", "discrete"}: - raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.") - if requested_mode == "continuous" and training_mode == "discrete": - raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.") - if requested_mode == "discrete" and training_mode == "continuous": - raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.") - return requested_mode diff --git a/src/lerobot/policies/molmoact2/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/modeling_molmoact2.py index f86be0904..2cc85ab02 100644 --- a/src/lerobot/policies/molmoact2/modeling_molmoact2.py +++ b/src/lerobot/policies/molmoact2/modeling_molmoact2.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,9 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""MolmoAct2 policy for LeRobot. + +MolmoAct2 is a VLM-based robotics policy from Allen AI that combines a +Molmo vision-language backbone with a per-layer flow-matching action expert +for continuous action generation, plus an optional discrete action token +head. This module wraps the vendored HF model implementation +(``molmoact2_hf_model/``) into the LeRobot ``PreTrainedPolicy`` interface. + +Paper: https://allenai.org/blog/molmoact2 +Code: https://github.com/allenai/molmoact2 +""" + from __future__ import annotations import json +import logging import os import types from collections import deque @@ -35,13 +46,58 @@ from lerobot.utils.constants import ACTION from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package from ..rtc.modeling_rtc import RTCProcessor -from .configuration_molmoact2 import MolmoAct2Config, _hf_token, _resolve_checkpoint_location +from .configuration_molmoact2 import MolmoAct2Config + +logger = logging.getLogger(__name__) + + +def _hf_token() -> str | None: + return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") + + +def _resolve_checkpoint_location( + checkpoint_path: str, + *, + revision: str | None = None, + force_download: bool = False, +) -> str: + """Resolve a checkpoint path to a local directory, downloading from Hub if needed.""" + checkpoint_path = str(checkpoint_path or "").strip() + if not checkpoint_path: + raise ValueError("MolmoAct2 policy requires `checkpoint_path`.") + from pathlib import Path + + local_path = Path(checkpoint_path).expanduser() + if local_path.exists(): + return str(local_path) + from huggingface_hub import snapshot_download + + return snapshot_download( + repo_id=checkpoint_path, + repo_type="model", + revision=revision, + force_download=force_download, + ignore_patterns=["*.py", "*.pyc", "__pycache__/*"], + token=_hf_token(), + ) + + +def _torch_dtype(dtype: str) -> torch.dtype: + """Convert a dtype name string to a torch.dtype.""" + if dtype == "float32": + return torch.float32 + if dtype == "bfloat16": + return torch.bfloat16 + if dtype == "float16": + return torch.float16 + raise ValueError(f"Unsupported dtype: {dtype}") + if TYPE_CHECKING or _transformers_available: from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME - from .hf_model.configuration_molmoact2 import MolmoAct2Config as HFMolmoAct2Config - from .hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration + from .molmoact2_hf_model.configuration_molmoact2 import MolmoAct2Config as HFMolmoAct2Config + from .molmoact2_hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration else: SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" SAFE_WEIGHTS_NAME = "model.safetensors" @@ -49,7 +105,7 @@ else: MolmoAct2ForConditionalGeneration = None if TYPE_CHECKING or (_transformers_available and _scipy_available): - from .hf_model.action_tokenizer import UniversalActionProcessor + from .molmoact2_hf_model.action_tokenizer import UniversalActionProcessor else: UniversalActionProcessor = None @@ -70,6 +126,156 @@ _MODEL_INPUT_KEYS = { } +def _load_hf_norm_metadata_for_tag( + checkpoint_path: str, + *, + revision: str | None, + force_download: bool, + norm_tag: str | None, +) -> dict[str, Any]: + """Read per-tag metadata from the checkpoint's ``norm_stats.json``.""" + norm_tag = str(norm_tag or "").strip() + if not norm_tag: + return {} + from contextlib import suppress + from pathlib import Path + + checkpoint_location = Path( + _resolve_checkpoint_location( + checkpoint_path, + revision=revision, + force_download=force_download, + ) + ) + norm_stats_filename = "norm_stats.json" + config_path = checkpoint_location / "config.json" + if config_path.exists(): + with suppress(OSError, json.JSONDecodeError): + norm_stats_filename = str( + json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename + ) + stats_path = checkpoint_location / norm_stats_filename + if not stats_path.exists(): + raise FileNotFoundError( + f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}." + ) + payload = json.loads(stats_path.read_text()) + metadata_by_tag = payload.get("metadata_by_tag") + if not isinstance(metadata_by_tag, dict): + raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.") + metadata = metadata_by_tag.get(norm_tag) + if not isinstance(metadata, dict): + available = sorted(str(tag) for tag in metadata_by_tag) + raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.") + return metadata + + +def _apply_norm_tag_metadata(config: MolmoAct2Config) -> None: + """Populate config fields from the checkpoint's norm-tag metadata.""" + if not str(config.norm_tag or "").strip(): + return + metadata = _load_hf_norm_metadata_for_tag( + config.checkpoint_path, + revision=config.checkpoint_revision, + force_download=bool(config.checkpoint_force_download), + norm_tag=config.norm_tag, + ) + if metadata.get("action_horizon") is not None: + config.chunk_size = int(metadata["action_horizon"]) + if metadata.get("n_action_steps") is not None: + config.n_action_steps = int(metadata["n_action_steps"]) + if not config.setup_type and metadata.get("setup_type") is not None: + config.setup_type = str(metadata["setup_type"]) + if not config.control_mode and metadata.get("control_mode") is not None: + config.control_mode = str(metadata["control_mode"]) + + +def _saved_policy_action_mode(config: MolmoAct2Config) -> str | None: + """Read the action mode from a LeRobot-saved checkpoint's ``config.json``.""" + from pathlib import Path + + pretrained_path = getattr(config, "pretrained_path", None) + if pretrained_path is None: + return None + config_path = Path(pretrained_path) / "config.json" + if not config_path.exists(): + return None + try: + mode = json.loads(config_path.read_text()).get("action_mode") + except (OSError, json.JSONDecodeError): + return None + if mode in {"continuous", "discrete", "both"}: + return str(mode) + return None + + +def _training_action_mode(config: MolmoAct2Config, saved_policy_action_mode: str | None = None) -> str: + return saved_policy_action_mode or config.action_mode + + +def _validate_inference_action_mode( + config: MolmoAct2Config, saved_policy_action_mode: str | None = None +) -> None: + """Check that the requested inference mode is compatible with the training mode.""" + requested_mode = config.inference_action_mode + if requested_mode is None: + return + training_mode = _training_action_mode(config, saved_policy_action_mode) + if requested_mode == "continuous" and training_mode == "discrete": + raise ValueError( + "MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run " + "continuous inference." + ) + if requested_mode == "discrete" and training_mode == "continuous": + raise ValueError( + "MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run " + "discrete inference. Train with action_mode='both' or action_mode='discrete' first." + ) + + +def _validate_checkpoint_action_mode( + config: MolmoAct2Config, + checkpoint_action_mode: str, + *, + has_action_expert: bool, +) -> None: + """Check that the checkpoint's action mode is compatible with the config.""" + if config.action_mode == "both" and checkpoint_action_mode != "both": + raise ValueError( + f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}." + ) + if config.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}: + raise ValueError( + f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, " + f"got {checkpoint_action_mode!r}." + ) + if config.action_mode in {"continuous", "both"} and not has_action_expert: + raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.") + + +def _resolve_inference_action_mode( + config: MolmoAct2Config, + requested_mode: str | None, + saved_policy_action_mode: str | None = None, +) -> str: + """Resolve the final inference action mode, validating compatibility.""" + training_mode = _training_action_mode(config, saved_policy_action_mode) + if requested_mode is None: + requested_mode = config.inference_action_mode + if requested_mode is None: + raise ValueError( + "MolmoAct2 inference requires `inference_action_mode` to be set explicitly " + "to either 'continuous' or 'discrete'." + ) + if requested_mode not in {"continuous", "discrete"}: + raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.") + if requested_mode == "continuous" and training_mode == "discrete": + raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.") + if requested_mode == "discrete" and training_mode == "continuous": + raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.") + return requested_mode + + def _strict_load_safetensors_weights(model: torch.nn.Module, checkpoint_location: str) -> None: index_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_INDEX_NAME) single_file_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_NAME) @@ -103,16 +309,6 @@ def _strict_load_safetensors_weights(model: torch.nn.Module, checkpoint_location ) -def _torch_dtype(dtype: str) -> torch.dtype: - if dtype == "float32": - return torch.float32 - if dtype == "bfloat16": - return torch.bfloat16 - if dtype == "float16": - return torch.float16 - raise ValueError(f"Unsupported dtype: {dtype}") - - def _sample_beta_timesteps( *, batch_size: int, @@ -136,7 +332,180 @@ def _sample_beta_timesteps( return time_offset + scale * samples +def _mask_discrete_action_spans( + *, + input_ids: Tensor, + mask: Tensor, + start_token_id: int | None, + end_token_id: int | None, +) -> Tensor: + if start_token_id is None or end_token_id is None: + return mask + mask = mask.clone() + for batch_idx in range(input_ids.shape[0]): + row = input_ids[batch_idx] + starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist() + ends = (row == int(end_token_id)).nonzero(as_tuple=False).flatten().tolist() + end_ptr = 0 + for start in starts: + while end_ptr < len(ends) and ends[end_ptr] < start: + end_ptr += 1 + if end_ptr >= len(ends): + mask[batch_idx, start:] = False + break + end = int(ends[end_ptr]) + mask[batch_idx, start : end + 1] = False + end_ptr += 1 + return mask + + +def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]: + attention_mask = model_inputs.get("attention_mask") + if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()): + model_inputs = dict(model_inputs) + model_inputs.pop("attention_mask", None) + return model_inputs + + +def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None: + if mask is None: + return None + return ( + mask.unsqueeze(1) + .expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1))) + .reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:]) + ) + + +def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None: + if action_dim_is_pad is None: + return None + mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool) + if mask.ndim == 1: + mask = mask.unsqueeze(0) + if mask.shape[-1] != target.shape[-1]: + raise ValueError( + f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}." + ) + if mask.shape[0] == 1 and target.shape[0] != 1: + mask = mask.expand(target.shape[0], -1) + if mask.shape[0] != target.shape[0]: + raise ValueError( + f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}." + ) + while mask.ndim < target.ndim: + mask = mask.unsqueeze(1) + return mask + + +def _mask_action_dim_tensor(tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor: + if action_dim_is_pad is None: + return tensor + valid_mask = _action_dim_valid_mask(tensor, action_dim_is_pad) + if valid_mask is None: + return tensor + return tensor.masked_fill(~valid_mask, 0) + + +def _apply_action_dim_padding_mask(loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor: + valid_mask = _action_dim_valid_mask(loss, action_dim_is_pad) + if valid_mask is None: + return loss + valid = valid_mask.to(dtype=loss.dtype) + denom = valid.sum(dim=-1).clamp_min(1.0) + return (loss * valid).sum(dim=-1) / denom + + +def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor: + if action_horizon_is_pad is None: + return loss + valid_action = ( + (~action_horizon_is_pad.to(device=loss.device, dtype=torch.bool)).unsqueeze(1).unsqueeze(-1) + ) + return loss * valid_action + + +def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int: + seed = 0 + for idx in range(batch_size): + seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1) + return seed + + +def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None: + task = batch.get("task") + if task is None: + task = batch.get("observation.language") + if task is None: + return None + if isinstance(task, str): + return (task,) + if isinstance(task, (list, tuple)): + return tuple(str(item) for item in task) + return (str(task),) + + +def _extract_discrete_token_bins( + generated_ids: list[int], + start_token_id: int, + end_token_id: int, + token_id_to_bin: dict[int, int], +) -> list[int]: + start_idx = None + end_idx = None + for idx, token_id in enumerate(generated_ids): + if token_id == start_token_id: + start_idx = idx + break + if start_idx is not None: + for idx in range(start_idx + 1, len(generated_ids)): + if generated_ids[idx] == end_token_id: + end_idx = idx + break + span_start = 0 if start_idx is None else start_idx + 1 + span_end = len(generated_ids) if end_idx is None else end_idx + return [ + int(token_id_to_bin[token_id]) + for token_id in generated_ids[span_start:span_end] + if token_id in token_id_to_bin + ] + + +def _weighted_mean(values: Tensor, weights: Tensor | None) -> Tensor: + if weights is None: + return values.mean() + weights = weights.to(device=values.device, dtype=values.dtype) + return torch.dot(values, weights) / weights.sum().clamp_min(1.0) + + +def _weighted_per_example( + values: Tensor, + weights: Tensor | None, + example_indices: Tensor, + batch_size: int, +) -> Tensor: + values = values.float() + if weights is None: + weights = torch.ones_like(values) + else: + weights = weights.to(device=values.device, dtype=values.dtype) + loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32) + weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32) + loss_sum.scatter_add_(0, example_indices, values * weights) + weight_sum.scatter_add_(0, example_indices, weights) + global_weight_sum = weight_sum.sum().clamp_min(1.0) + return loss_sum * float(batch_size) / global_weight_sum + + class MolmoAct2Policy(PreTrainedPolicy): + """MolmoAct2 policy wrapping the vendored HF model for LeRobot. + + Supports three training modes via ``config.action_mode``: + ``"continuous"`` (flow-matching only), ``"discrete"`` (autoregressive + token prediction only), or ``"both"`` (joint loss). At inference, + ``config.inference_action_mode`` selects which head generates actions. + """ + config_class = MolmoAct2Config name = "molmoact2" @@ -149,10 +518,10 @@ class MolmoAct2Policy(PreTrainedPolicy): **kwargs, ): super().__init__(config, *inputs, **kwargs) - self.config.apply_norm_tag_metadata() + _apply_norm_tag_metadata(self.config) self.config.validate_features() del inputs, kwargs, dataset_stats, dataset_meta - self._checkpoint_action_mode = self.config.saved_policy_action_mode() + self._checkpoint_action_mode = _saved_policy_action_mode(self.config) self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps) self._rollout_action_generator: torch.Generator | None = None self._rollout_task_key: tuple[Any, ...] | None = None @@ -160,7 +529,7 @@ class MolmoAct2Policy(PreTrainedPolicy): self.rtc_processor: RTCProcessor | None = None self.action_tokenizer: Any | None = None self._load_hf_model() - self.config.validate_inference_action_mode(self._checkpoint_action_mode) + _validate_inference_action_mode(self.config, self._checkpoint_action_mode) if self.config.enable_lora_vlm: self._apply_lora_adapters() self.init_rtc_processor() @@ -212,7 +581,8 @@ class MolmoAct2Policy(PreTrainedPolicy): "`policy.checkpoint_force_download=true` after the updated files are pushed." ) checkpoint_action_mode = str(self.model.config.action_mode) - self.config.validate_checkpoint_action_mode( + _validate_checkpoint_action_mode( + self.config, checkpoint_action_mode, has_action_expert=bool(getattr(self.model.config, "add_action_expert", False)), ) @@ -226,6 +596,7 @@ class MolmoAct2Policy(PreTrainedPolicy): self.train(self.training) def reset(self) -> None: + """Clear the action queue and rollout generator between episodes.""" self._action_queue = deque(maxlen=self.config.n_action_steps) self._rollout_action_generator = None @@ -334,6 +705,7 @@ class MolmoAct2Policy(PreTrainedPolicy): param.requires_grad = False def get_optim_params(self) -> list[dict[str, Any]]: + """Return optimizer param groups with per-component learning rates.""" vit_params: list[Tensor] = [] connector_params: list[Tensor] = [] action_expert_params: list[Tensor] = [] @@ -419,33 +791,6 @@ class MolmoAct2Policy(PreTrainedPolicy): return int(value) raise RuntimeError("MolmoAct2 could not resolve an action generation horizon.") - @staticmethod - def _mask_discrete_action_spans( - *, - input_ids: Tensor, - mask: Tensor, - start_token_id: int | None, - end_token_id: int | None, - ) -> Tensor: - if start_token_id is None or end_token_id is None: - return mask - mask = mask.clone() - for batch_idx in range(input_ids.shape[0]): - row = input_ids[batch_idx] - starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist() - ends = (row == int(end_token_id)).nonzero(as_tuple=False).flatten().tolist() - end_ptr = 0 - for start in starts: - while end_ptr < len(ends) and ends[end_ptr] < start: - end_ptr += 1 - if end_ptr >= len(ends): - mask[batch_idx, start:] = False - break - end = int(ends[end_ptr]) - mask[batch_idx, start : end + 1] = False - end_ptr += 1 - return mask - def _encoder_attention_mask_for_action_expert( self, *, @@ -470,21 +815,13 @@ class MolmoAct2Policy(PreTrainedPolicy): eos_token_id = getattr(self.model.config, "eos_token_id", None) if eos_token_id is not None: mask &= input_ids != int(eos_token_id) - return self._mask_discrete_action_spans( + return _mask_discrete_action_spans( input_ids=input_ids, mask=mask, start_token_id=getattr(self.model.config, "action_start_token_id", None), end_token_id=getattr(self.model.config, "action_end_token_id", None), ) - @staticmethod - def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]: - attention_mask = model_inputs.get("attention_mask") - if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()): - model_inputs = dict(model_inputs) - model_inputs.pop("attention_mask", None) - return model_inputs - def _load_discrete_action_tokenizer(self) -> Any: if self.action_tokenizer is None: require_package("transformers", extra="molmoact2") @@ -498,27 +835,7 @@ class MolmoAct2Policy(PreTrainedPolicy): return self.action_tokenizer def _resolve_inference_action_mode(self, requested_mode: str | None) -> str: - return self.config.resolve_inference_action_mode(requested_mode, self._checkpoint_action_mode) - - @staticmethod - def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int: - seed = 0 - for idx in range(batch_size): - seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1) - return seed - - @staticmethod - def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None: - task = batch.get("task") - if task is None: - task = batch.get("observation.language") - if task is None: - return None - if isinstance(task, str): - return (task,) - if isinstance(task, (list, tuple)): - return tuple(str(item) for item in task) - return (str(task),) + return _resolve_inference_action_mode(self.config, requested_mode, self._checkpoint_action_mode) def _rollout_generator_for_inputs( self, @@ -532,7 +849,7 @@ class MolmoAct2Policy(PreTrainedPolicy): if self._rollout_action_generator is not None: return self._rollout_action_generator - task_signature = self._rollout_task_signature(batch) + task_signature = _rollout_task_signature(batch) if task_signature != self._rollout_task_key: self._rollout_task_key = task_signature self._rollout_index_for_task = 0 @@ -545,72 +862,10 @@ class MolmoAct2Policy(PreTrainedPolicy): device if device.type == "cuda" and torch.cuda.is_available() else torch.device("cpu") ) generator = torch.Generator(device=generator_device) - generator.manual_seed(self._combine_rollout_seeds(first_seed, batch_size)) + generator.manual_seed(_combine_rollout_seeds(first_seed, batch_size)) self._rollout_action_generator = generator return generator - @staticmethod - def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None: - if mask is None: - return None - return ( - mask.unsqueeze(1) - .expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1))) - .reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:]) - ) - - @staticmethod - def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None: - if action_dim_is_pad is None: - return None - mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool) - if mask.ndim == 1: - mask = mask.unsqueeze(0) - if mask.shape[-1] != target.shape[-1]: - raise ValueError( - f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}." - ) - if mask.shape[0] == 1 and target.shape[0] != 1: - mask = mask.expand(target.shape[0], -1) - if mask.shape[0] != target.shape[0]: - raise ValueError( - f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}." - ) - while mask.ndim < target.ndim: - mask = mask.unsqueeze(1) - return mask - - @classmethod - def _mask_action_dim_tensor(cls, tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor: - if not cls._mask_enabled_static(action_dim_is_pad): - return tensor - valid_mask = cls._action_dim_valid_mask(tensor, action_dim_is_pad) - if valid_mask is None: - return tensor - return tensor.masked_fill(~valid_mask, 0) - - @staticmethod - def _mask_enabled_static(action_dim_is_pad: Tensor | None) -> bool: - return action_dim_is_pad is not None - - @classmethod - def _apply_action_dim_padding_mask(cls, loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor: - valid_mask = cls._action_dim_valid_mask(loss, action_dim_is_pad) - if valid_mask is None: - return loss - valid = valid_mask.to(dtype=loss.dtype) - denom = valid.sum(dim=-1).clamp_min(1.0) - return (loss * valid).sum(dim=-1) / denom - - @staticmethod - def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor: - if action_horizon_is_pad is None: - return loss - valid_action = ( - (~action_horizon_is_pad.to(device=loss.device, dtype=torch.bool)).unsqueeze(1).unsqueeze(-1) - ) - return loss * valid_action - def _prepare_flow_matching_tensors( self, *, @@ -649,7 +904,7 @@ class MolmoAct2Policy(PreTrainedPolicy): ) if self.config.mask_action_dim_padding: - actions = self._mask_action_dim_tensor(actions, action_dim_is_pad) + actions = _mask_action_dim_tensor(actions, action_dim_is_pad) expected_noise_shape = (batch_size, num_flow_timesteps, actions.shape[1], actions.shape[2]) if noise is None: @@ -661,7 +916,7 @@ class MolmoAct2Policy(PreTrainedPolicy): f"flow noise must have shape {expected_noise_shape}, got {tuple(noise.shape)}." ) if self.config.mask_action_dim_padding: - noise = self._mask_action_dim_tensor(noise, action_dim_is_pad) + noise = _mask_action_dim_tensor(noise, action_dim_is_pad) t_broadcast = timesteps.view(batch_size, num_flow_timesteps, 1, 1) actions_expanded = actions.unsqueeze(1).expand(-1, num_flow_timesteps, -1, -1) @@ -789,7 +1044,7 @@ class MolmoAct2Policy(PreTrainedPolicy): valid_action = None if action_attention_mask is not None: valid_action = action_attention_mask.to(device=device, dtype=actions.dtype).unsqueeze(-1) - valid_action = self._expand_mask(valid_action, num_flow_timesteps) + valid_action = _expand_mask(valid_action, num_flow_timesteps) rope_cache = None if len(action_expert.blocks) > 0 and action_expert.blocks[0].self_attn.rope is not None: @@ -804,14 +1059,14 @@ class MolmoAct2Policy(PreTrainedPolicy): batch_size, actions.dtype, ) - cross_mask = self._expand_mask(cross_mask, num_flow_timesteps) + cross_mask = _expand_mask(cross_mask, num_flow_timesteps) self_mask = action_expert._build_self_attention_mask( action_attention_mask, actions.shape[1], device, actions.dtype, ) - self_mask = self._expand_mask(self_mask, num_flow_timesteps) + self_mask = _expand_mask(self_mask, num_flow_timesteps) conditioning = self._action_time_conditioning(action_expert, timesteps_flat) action_hidden = action_expert.action_embed(xt_flat) @@ -871,8 +1126,8 @@ class MolmoAct2Policy(PreTrainedPolicy): if k_norm is not None: k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2) if num_flow_timesteps != 1: - k_ctx = self._expand_mask(k_ctx, num_flow_timesteps) - v_ctx = self._expand_mask(v_ctx, num_flow_timesteps) + k_ctx = _expand_mask(k_ctx, num_flow_timesteps) + v_ctx = _expand_mask(v_ctx, num_flow_timesteps) next_action_hidden = action_block( layer_action_hidden, @@ -912,9 +1167,9 @@ class MolmoAct2Policy(PreTrainedPolicy): ) loss = F.mse_loss(pred_velocity, target_velocity, reduction="none") - loss = self._apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad")) + loss = _apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad")) if self.config.mask_action_dim_padding: - loss = self._apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad")) + loss = _apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad")) loss = loss.reshape(batch_size, -1).mean(dim=1) if reduction == "mean": loss = loss.mean() @@ -933,32 +1188,6 @@ class MolmoAct2Policy(PreTrainedPolicy): example_weights[nonempty] = 2.0 / torch.sqrt(token_counts[nonempty]) return example_weights[:, None].expand_as(valid_positions)[valid_positions].to(dtype=torch.float32) - @staticmethod - def _weighted_mean(values: Tensor, weights: Tensor | None) -> Tensor: - if weights is None: - return values.mean() - weights = weights.to(device=values.device, dtype=values.dtype) - return torch.dot(values, weights) / weights.sum().clamp_min(1.0) - - @staticmethod - def _weighted_per_example( - values: Tensor, - weights: Tensor | None, - example_indices: Tensor, - batch_size: int, - ) -> Tensor: - values = values.float() - if weights is None: - weights = torch.ones_like(values) - else: - weights = weights.to(device=values.device, dtype=values.dtype) - loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32) - weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32) - loss_sum.scatter_add_(0, example_indices, values * weights) - weight_sum.scatter_add_(0, example_indices, weights) - global_weight_sum = weight_sum.sum().clamp_min(1.0) - return loss_sum * float(batch_size) / global_weight_sum - def _discrete_loss_from_backbone_outputs( self, batch: dict[str, Tensor], @@ -992,56 +1221,28 @@ class MolmoAct2Policy(PreTrainedPolicy): token_weights = self._discrete_token_weights(valid_positions) if reduction == "none": example_indices = valid_positions.nonzero(as_tuple=False)[:, 0].to(device=hidden_states.device) - ce_loss = self._weighted_per_example( + ce_loss = _weighted_per_example( token_ce_loss, token_weights, example_indices, int(labels.shape[0]), ) else: - ce_loss = self._weighted_mean(token_ce_loss, token_weights) + ce_loss = _weighted_mean(token_ce_loss, token_weights) if not self.config.softmax_auxiliary_loss: return ce_loss, None if reduction == "none": - z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_per_example( + z_loss = self.config.softmax_auxiliary_loss_scale * _weighted_per_example( log_z.pow(2), token_weights, example_indices, int(labels.shape[0]), ) else: - z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_mean( - log_z.pow(2), token_weights - ) + z_loss = self.config.softmax_auxiliary_loss_scale * _weighted_mean(log_z.pow(2), token_weights) return ce_loss, z_loss - @staticmethod - def _extract_discrete_token_bins( - generated_ids: list[int], - start_token_id: int, - end_token_id: int, - token_id_to_bin: dict[int, int], - ) -> list[int]: - start_idx = None - end_idx = None - for idx, token_id in enumerate(generated_ids): - if token_id == start_token_id: - start_idx = idx - break - if start_idx is not None: - for idx in range(start_idx + 1, len(generated_ids)): - if generated_ids[idx] == end_token_id: - end_idx = idx - break - span_start = 0 if start_idx is None else start_idx + 1 - span_end = len(generated_ids) if end_idx is None else end_idx - return [ - int(token_id_to_bin[token_id]) - for token_id in generated_ids[span_start:span_end] - if token_id in token_id_to_bin - ] - def _action_token_id_to_bin(self) -> dict[int, int]: method = getattr(self.model, "_action_token_id_to_bin", None) if callable(method): @@ -1179,7 +1380,7 @@ class MolmoAct2Policy(PreTrainedPolicy): chunks: list[Tensor] = [] for token_row in generated_token_ids: generated_ids = [int(token_id) for token_id in token_row.detach().cpu().tolist()] - discrete_token_ids = self._extract_discrete_token_bins( + discrete_token_ids = _extract_discrete_token_bins( generated_ids, int(self.model.config.action_start_token_id), int(self.model.config.action_end_token_id), @@ -1218,7 +1419,7 @@ class MolmoAct2Policy(PreTrainedPolicy): model_inputs: dict[str, Tensor], action_dim: int, ) -> Tensor: - model_inputs = self._drop_trivial_attention_mask(model_inputs) + model_inputs = _drop_trivial_attention_mask(model_inputs) max_steps = self._discrete_generation_max_steps() static_cache, attention_bias = self._make_discrete_ar_graph_decode_inputs( model_inputs, @@ -1294,7 +1495,7 @@ class MolmoAct2Policy(PreTrainedPolicy): generator=generator, ) if self.config.mask_action_dim_padding: - trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad) + trajectory = _mask_action_dim_tensor(trajectory, action_dim_is_pad) action_context = action_expert.prepare_context( encoder_kv_states=encoder_kv_states, @@ -1327,7 +1528,7 @@ class MolmoAct2Policy(PreTrainedPolicy): modulation=step_modulation, ) if mask_enabled: - velocity = self._mask_action_dim_tensor(velocity, action_dim_is_pad) + velocity = _mask_action_dim_tensor(velocity, action_dim_is_pad) return velocity if self._rtc_enabled(): @@ -1352,7 +1553,7 @@ class MolmoAct2Policy(PreTrainedPolicy): trajectory = trajectory + dt * velocity if mask_enabled: - trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad) + trajectory = _mask_action_dim_tensor(trajectory, action_dim_is_pad) if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=float(flow_timestep[0].item()), x_t=trajectory, v_t=velocity) @@ -1363,6 +1564,7 @@ class MolmoAct2Policy(PreTrainedPolicy): batch: dict[str, Tensor], reduction: str = "mean", ) -> tuple[Tensor, dict[str, Any]]: + """Compute training loss (flow-matching and/or discrete token loss).""" if reduction not in {"mean", "none"}: raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.") model_inputs = self._model_inputs(batch) @@ -1422,6 +1624,7 @@ class MolmoAct2Policy(PreTrainedPolicy): @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + """Generate an action chunk via continuous flow matching or discrete AR decoding.""" if "action_mode" in kwargs: raise TypeError( "MolmoAct2 predict_action_chunk got unexpected keyword argument 'action_mode'; " @@ -1476,6 +1679,7 @@ class MolmoAct2Policy(PreTrainedPolicy): @torch.no_grad() def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + """Pop one action step from the queue, regenerating the chunk when empty.""" if self._rtc_enabled(): raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk") self.eval() diff --git a/src/lerobot/policies/molmoact2/hf_model/__init__.py b/src/lerobot/policies/molmoact2/molmoact2_hf_model/__init__.py similarity index 94% rename from src/lerobot/policies/molmoact2/hf_model/__init__.py rename to src/lerobot/policies/molmoact2/molmoact2_hf_model/__init__.py index 39b15cb3a..4436c9fda 100644 --- a/src/lerobot/policies/molmoact2/hf_model/__init__.py +++ b/src/lerobot/policies/molmoact2/molmoact2_hf_model/__init__.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,5 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# ruff: noqa diff --git a/src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py b/src/lerobot/policies/molmoact2/molmoact2_hf_model/action_tokenizer.py similarity index 96% rename from src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py rename to src/lerobot/policies/molmoact2/molmoact2_hf_model/action_tokenizer.py index f7dacbce6..11a228731 100644 --- a/src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py +++ b/src/lerobot/policies/molmoact2/molmoact2_hf_model/action_tokenizer.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,23 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -# ruff: noqa - import logging -import os from pathlib import Path from typing import ClassVar import numpy as np from tokenizers import ByteLevelBPETokenizer from tokenizers.trainers import BpeTrainer -from huggingface_hub import snapshot_download from transformers import PreTrainedTokenizerFast from transformers.processing_utils import ProcessorMixin +from ..modeling_molmoact2 import _hf_token -def _hf_token() -> str | None: - return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") +logger = logging.getLogger(__name__) def _resolve_tokenizer_location( @@ -42,6 +36,8 @@ def _resolve_tokenizer_location( local_path = Path(str(tokenizer_path)).expanduser() if local_path.exists(): return str(local_path) + from huggingface_hub import snapshot_download + return snapshot_download( repo_id=str(tokenizer_path), repo_type="model", @@ -134,9 +130,8 @@ class UniversalActionProcessor(ProcessorMixin): ), ( f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" ) - except Exception as e: - print(f"Error decoding tokens: {e}") - print(f"Tokens: {token}") + except Exception: + logger.warning("Error decoding tokens: %s", token, exc_info=True) decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim)) decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho")) return np.stack(decoded_actions) diff --git a/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py b/src/lerobot/policies/molmoact2/molmoact2_hf_model/configuration_molmoact2.py similarity index 99% rename from src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py rename to src/lerobot/policies/molmoact2/molmoact2_hf_model/configuration_molmoact2.py index 29da68c14..df5449bef 100644 --- a/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py +++ b/src/lerobot/policies/molmoact2/molmoact2_hf_model/configuration_molmoact2.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# ruff: noqa """ MolmoAct2 configuration """ -from typing import Optional, Any +from typing import Any from transformers import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation diff --git a/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py b/src/lerobot/policies/molmoact2/molmoact2_hf_model/image_processing_molmoact2.py similarity index 98% rename from src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py rename to src/lerobot/policies/molmoact2/molmoact2_hf_model/image_processing_molmoact2.py index a172c8477..acc709cb5 100644 --- a/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py +++ b/src/lerobot/policies/molmoact2/molmoact2_hf_model/image_processing_molmoact2.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,33 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -# ruff: noqa """Image processor class for MolmoAct2""" -from typing import Optional, Union -import numpy as np import einops +import numpy as np import torch import torchvision.transforms - +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_processing_utils import BaseImageProcessor, get_size_dict +from transformers.image_transforms import convert_to_rgb from transformers.image_utils import ( IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, make_flat_list_of_images, - valid_images, to_numpy_array, + valid_images, ) -from transformers.image_transforms import convert_to_rgb from transformers.processing_utils import ImagesKwargs -from transformers.image_processing_utils import BaseImageProcessor, get_size_dict -from transformers.utils import logging -from transformers.feature_extraction_utils import BatchFeature from transformers.utils import TensorType, logging - logger = logging.get_logger(__name__) @@ -73,8 +66,8 @@ def resize_image( )(image) resized = torch.clip(resized, 0.0, 1.0).to(dtype) else: - assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format( - image.dtype + assert image.dtype == torch.uint8, ( + f"SigLIP expects float images or uint8 images, but got {image.dtype}" ) in_min = 0.0 in_max = 255.0 @@ -96,7 +89,6 @@ def resize_image( def select_tiling(h, w, patch_size, max_num_crops): """Divide in image of size [w, h] in up to max_num_patches of size patch_size""" original_size = np.stack([h, w]) # [1, 2] - original_res = h * w tilings = [] for i in range(1, max_num_crops + 1): for j in range(1, max_num_crops + 1): @@ -406,13 +398,17 @@ class MolmoAct2ImageProcessor(BaseImageProcessor): image_std: float | list[float] | None = None, do_convert_rgb: bool = True, max_crops: int = 8, - overlap_margins: list[int] = [4, 4], + overlap_margins: list[int] | None = None, crop_mode: str = "overlap-and-resize-c2", patch_size: int = 14, - pooling_size: list[int] = [2, 2], + pooling_size: list[int] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) + if overlap_margins is None: + overlap_margins = [4, 4] + if pooling_size is None: + pooling_size = [2, 2] size = size if size is not None else {"height": 378, "width": 378} size = get_size_dict(size, default_to_square=True) self.size = size diff --git a/src/lerobot/policies/molmoact2/hf_model/inference.py b/src/lerobot/policies/molmoact2/molmoact2_hf_model/inference.py similarity index 99% rename from src/lerobot/policies/molmoact2/hf_model/inference.py rename to src/lerobot/policies/molmoact2/molmoact2_hf_model/inference.py index 2c0243880..428800a8c 100644 --- a/src/lerobot/policies/molmoact2/hf_model/inference.py +++ b/src/lerobot/policies/molmoact2/molmoact2_hf_model/inference.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -# ruff: noqa """Inference utilities for MolmoAct2""" -from dataclasses import dataclass -from typing import Any, Optional, Tuple from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from typing import Any import torch -from torch.nn import functional as F +from torch.nn import functional as F # noqa: N812 from transformers.cache_utils import Cache from transformers.configuration_utils import PretrainedConfig @@ -679,7 +676,7 @@ def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs: def _copy_context_(dst: Any, src: Any) -> None: - for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts): + for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts, strict=False): dst_k.copy_(src_k) dst_v.copy_(src_v) if src.cross_mask is not None: @@ -689,7 +686,7 @@ def _copy_context_(dst: Any, src: Any) -> None: if src.valid_action is not None: dst.valid_action.copy_(src.valid_action) if src.rope_cache is not None: - for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache): + for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache, strict=False): dst_tensor.copy_(src_tensor) diff --git a/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/molmoact2_hf_model/modeling_molmoact2.py similarity index 99% rename from src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py rename to src/lerobot/policies/molmoact2/molmoact2_hf_model/modeling_molmoact2.py index 4c36b04c8..e2edbe68d 100644 --- a/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py +++ b/src/lerobot/policies/molmoact2/molmoact2_hf_model/modeling_molmoact2.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,24 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -# ruff: noqa """Modeling code for MolmoAct2""" +# ruff: noqa: N806 + import json import math import os import re +from collections.abc import Callable, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union -from collections.abc import Callable, Mapping, Sequence +from typing import Any, Optional import numpy as np import torch import torch.utils.checkpoint from torch import nn -from torch.nn import functional as F +from torch.nn import functional as F # noqa: N812 from torch.nn.attention import SDPBackend, sdpa_kernel from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache @@ -647,7 +646,7 @@ class ActionExpert(nn.Module): f"got {len(encoder_kv_states)}." ) kv_contexts = [] - for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states): + for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states, strict=False): k_ctx = self._project_kv_tensor(k_in, self.context_k_proj) v_ctx = self._project_kv_tensor(v_in, self.context_v_proj) k_norm = block.cross_attn.k_norm @@ -732,7 +731,7 @@ class ActionExpert(nn.Module): timesteps: Sequence[torch.Tensor], ) -> Sequence[ActionExpertStepModulation]: cache = [] - for idx, step_t in enumerate(timesteps): + for _idx, step_t in enumerate(timesteps): conditioning = self._time_conditioning(step_t) block_modulations = [] for block in self.blocks: @@ -786,8 +785,8 @@ class ActionExpert(nn.Module): x = self.action_embed(actions) if context.valid_action is not None: x = x * context.valid_action - for idx, (block, kv_context, block_modulation) in enumerate( - zip(self.blocks, context.kv_contexts, block_modulations) + for _idx, (block, kv_context, block_modulation) in enumerate( + zip(self.blocks, context.kv_contexts, block_modulations, strict=False) ): x = block( x, @@ -2874,7 +2873,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): depth_mask=depth_mask, encoder_attention_mask=encoder_attention_mask, ) - for gate, source in zip(gate_head, sources) + for gate, source in zip(gate_head, sources, strict=False) ] return gates, depth_mask gate = self._depth_gate_from_source( @@ -4458,7 +4457,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi ```python >>> from PIL import Image >>> import requests - >>> from lerobot.policies.molmoact2.hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration + >>> from lerobot.policies.molmoact2.molmoact2_hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration >>> from lerobot.policies.molmoact2.processor_molmoact2 import _load_local_molmoact2_processor >>> model = MolmoAct2ForConditionalGeneration.from_pretrained("...") diff --git a/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py b/src/lerobot/policies/molmoact2/molmoact2_hf_model/processing_molmoact2.py similarity index 95% rename from src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py rename to src/lerobot/policies/molmoact2/molmoact2_hf_model/processing_molmoact2.py index 7b8775faa..6a73d2465 100644 --- a/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py +++ b/src/lerobot/policies/molmoact2/molmoact2_hf_model/processing_molmoact2.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,45 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -# ruff: noqa """ Processor class for MolmoAct2. """ -from typing import Optional, Union -import dataclasses - import numpy as np - +from transformers import AutoTokenizer +from transformers.feature_extraction_utils import BatchFeature from transformers.image_utils import ImageInput -from transformers.video_utils import VideoInput from transformers.processing_utils import ( - Unpack, ProcessingKwargs, ProcessorMixin, + Unpack, ) -from transformers.feature_extraction_utils import BatchFeature -from transformers.tokenization_utils_base import TextInput, PreTokenizedInput +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.utils import logging +from transformers.video_utils import VideoInput -from transformers import AutoTokenizer -from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor -from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor - +from .image_processing_molmoact2 import MolmoAct2ImageProcessor, MolmoAct2ImagesKwargs +from .video_processing_molmoact2 import MolmoAct2VideoProcessor, MolmoAct2VideoProcessorKwargs logger = logging.get_logger(__name__) # Special tokens, these should be present in any tokenizer we use since the preprocessor uses them -IMAGE_PATCH_TOKEN = f"" # Where to insert high-res tokens -IMAGE_LOW_RES_TOKEN = f"" # Where to insert low-res tokens -IM_START_TOKEN = f"" -LOW_RES_IMAGE_START_TOKEN = f"" -FRAME_START_TOKEN = f"" -IM_END_TOKEN = f"" -FRAME_END_TOKEN = f"" -IM_COL_TOKEN = f"" +IMAGE_PATCH_TOKEN = "" # nosec B105 # Where to insert high-res tokens +IMAGE_LOW_RES_TOKEN = "" # nosec B105 # Where to insert low-res tokens +IM_START_TOKEN = "" # nosec B105 +LOW_RES_IMAGE_START_TOKEN = "" # nosec B105 +FRAME_START_TOKEN = "" # nosec B105 +IM_END_TOKEN = "" # nosec B105 +FRAME_END_TOKEN = "" # nosec B105 +IM_COL_TOKEN = "" # nosec B105 IMAGE_PROMPT = "<|image|>" VIDEO_PROMPT = "<|video|>" @@ -224,7 +216,7 @@ class MolmoAct2Processor(ProcessorMixin): input_ids = input_ids[None, :] attention_mask = attention_mask[None, :] - B, S = input_ids.shape + B, S = input_ids.shape # noqa: N806 # Handle zero-length sequence if S == 0: @@ -364,7 +356,7 @@ class MolmoAct2Processor(ProcessorMixin): assert num_videos in {0, 1}, "At most one video is supported for now" video_grids_i = video_grids[index : index + num_videos] metadata_i = video_metadata[index : index + num_videos] - for video_grid, metadata in zip(video_grids_i, metadata_i): + for video_grid, metadata in zip(video_grids_i, metadata_i, strict=False): video_string = self.get_video_string( video_grid, metadata.timestamps, diff --git a/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py b/src/lerobot/policies/molmoact2/molmoact2_hf_model/video_processing_molmoact2.py similarity index 98% rename from src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py rename to src/lerobot/policies/molmoact2/molmoact2_hf_model/video_processing_molmoact2.py index 644d5a691..bf4e44dde 100644 --- a/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py +++ b/src/lerobot/policies/molmoact2/molmoact2_hf_model/video_processing_molmoact2.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,25 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -# ruff: noqa """Video processor class for MolmoAct2""" -from functools import partial import os import warnings +from collections.abc import Callable from contextlib import redirect_stdout +from functools import partial from io import BytesIO from urllib.parse import urlparse -from typing import Optional, Union -from collections.abc import Callable +import einops import numpy as np import requests -import einops import torch import torchvision.transforms - +from transformers.feature_extraction_utils import BatchFeature from transformers.image_utils import ( IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, @@ -41,27 +37,24 @@ from transformers.image_utils import ( SizeDict, validate_kwargs, ) -from transformers.video_utils import ( - VideoInput, - is_valid_video, - make_batched_videos, - make_batched_metadata, - VideoMetadata, -) from transformers.processing_utils import Unpack, VideosKwargs -from transformers.video_processing_utils import BaseVideoProcessor -from transformers.utils import logging -from transformers.feature_extraction_utils import BatchFeature from transformers.utils import ( + TensorType, is_av_available, is_decord_available, is_torchcodec_available, is_yt_dlp_available, - TensorType, logging, to_numpy, ) - +from transformers.video_processing_utils import BaseVideoProcessor +from transformers.video_utils import ( + VideoInput, + VideoMetadata, + is_valid_video, + make_batched_metadata, + make_batched_videos, +) logger = logging.get_logger(__name__) @@ -102,8 +95,8 @@ def resize_image( )(image) resized = torch.clip(resized, 0.0, 1.0).to(dtype) else: - assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format( - image.dtype + assert image.dtype == torch.uint8, ( + f"SigLIP expects float images or uint8 images, but got {image.dtype}" ) in_min = 0.0 in_max = 255.0 @@ -548,9 +541,8 @@ def get_target_fps( step_size = max(int(video_fps / target_fps), 1) num_frames_sampled_at_fps = int(total_frames / step_size) if num_frames_sampled == 0: - if "uniform" in frame_sample_mode: - if num_frames_sampled_at_fps > max_frames: - break + if "uniform" in frame_sample_mode and num_frames_sampled_at_fps > max_frames: + break selected_target_fps = target_fps num_frames_sampled = num_frames_sampled_at_fps @@ -779,13 +771,15 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): elif is_torchcodec_available(): warnings.warn( "`decord` is not installed and cannot be used to decode the video by default. " - "Falling back to `torchcodec`." + "Falling back to `torchcodec`.", + stacklevel=2, ) backend = "torchcodec" else: warnings.warn( "`decord` is not installed and cannot be used to decode the video by default. " - "Falling back to `PyAV`." + "Falling back to `PyAV`.", + stacklevel=2, ) backend = "pyav" @@ -795,7 +789,8 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): *[ self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) for x in video_url_or_urls - ] + ], + strict=False, ) ) else: @@ -821,7 +816,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): assert video_metadata[0].fps is not None, "FPS must be provided for video input" sampled_videos = [] sampled_metadata = [] - for video, metadata in zip(videos, video_metadata): + for video, metadata in zip(videos, video_metadata, strict=False): indices = sample_indices_fn(metadata=metadata) metadata.frames_indices = indices sampled_videos.append(video[indices]) @@ -985,11 +980,11 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): pixel_values_videos = np.concatenate(batch_crops, 0) video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0) - data = dict( - pixel_values_videos=pixel_values_videos, - video_token_pooling=video_token_pooling, - video_grids=video_grids, - ) + data = { + "pixel_values_videos": pixel_values_videos, + "video_token_pooling": video_token_pooling, + "video_grids": video_grids, + } return BatchFeature(data, tensor_type=return_tensors) diff --git a/src/lerobot/policies/molmoact2/processor_molmoact2.py b/src/lerobot/policies/molmoact2/processor_molmoact2.py index 6c7a3ed5c..1303e94a1 100644 --- a/src/lerobot/policies/molmoact2/processor_molmoact2.py +++ b/src/lerobot/policies/molmoact2/processor_molmoact2.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,10 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""MolmoAct2 pre/post processing pipeline. + +Builds the multimodal prompt (images, discretised state, task text), +tokenises it via the vendored MolmoAct2 processor, and handles quantile +normalisation with optional per-dimension gripper masking. +""" + from __future__ import annotations import json -import os +import logging +import math import re from contextlib import suppress from copy import deepcopy @@ -27,7 +33,6 @@ from typing import TYPE_CHECKING, Any import numpy as np import torch -from huggingface_hub import snapshot_download from torch import Tensor from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature @@ -54,14 +59,71 @@ from lerobot.utils.constants import ( ) from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package -from .configuration_molmoact2 import MolmoAct2Config, infer_molmoact2_max_sequence_length +from .configuration_molmoact2 import MolmoAct2Config +from .modeling_molmoact2 import _hf_token, _resolve_checkpoint_location + +logger = logging.getLogger(__name__) + +MOLMOACT2_DEFAULT_NUM_IMAGES = 2 +MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196 +MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80 +MOLMOACT2_TASK_TOKEN_BUDGET = 32 +MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32 +MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64 +MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4 +MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6 +MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95 + + +def _round_up(value: int, multiple: int) -> int: + return int(math.ceil(value / multiple) * multiple) + + +def infer_molmoact2_max_sequence_length( + *, + num_images: int, + state_dim: int, + action_dim: int, + action_horizon: int, + include_discrete_action: bool, +) -> int: + """Infer the padded text/image sequence cap from MolmoAct2's fixed token layout.""" + if num_images < 1: + num_images = MOLMOACT2_DEFAULT_NUM_IMAGES + if state_dim < 0: + state_dim = 0 + if action_dim < 1: + action_dim = 1 + if action_horizon < 1: + action_horizon = 1 + + image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE + prompt_tokens = ( + MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET + + MOLMOACT2_TASK_TOKEN_BUDGET + + state_dim + + MOLMOACT2_SEQUENCE_LENGTH_MARGIN + ) + action_tokens = 0 + if include_discrete_action: + action_tokens_per_step = max( + MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP, + math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM), + ) + action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step + + return _round_up( + image_tokens + prompt_tokens + action_tokens, + MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE, + ) + if TYPE_CHECKING or _transformers_available: from transformers import Qwen2Tokenizer - from .hf_model.image_processing_molmoact2 import MolmoAct2ImageProcessor - from .hf_model.processing_molmoact2 import MolmoAct2Processor - from .hf_model.video_processing_molmoact2 import MolmoAct2VideoProcessor + from .molmoact2_hf_model.image_processing_molmoact2 import MolmoAct2ImageProcessor + from .molmoact2_hf_model.processing_molmoact2 import MolmoAct2Processor + from .molmoact2_hf_model.video_processing_molmoact2 import MolmoAct2VideoProcessor else: Qwen2Tokenizer = None MolmoAct2ImageProcessor = None @@ -69,7 +131,7 @@ else: MolmoAct2VideoProcessor = None if TYPE_CHECKING or (_transformers_available and _scipy_available): - from .hf_model.action_tokenizer import UniversalActionProcessor + from .molmoact2_hf_model.action_tokenizer import UniversalActionProcessor else: UniversalActionProcessor = None @@ -97,32 +159,6 @@ _QUESTION_PREFIX_PATTERNS = tuple( ) -def _hf_token() -> str | None: - return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") - - -def _resolve_checkpoint_location( - checkpoint_path: str, - *, - revision: str | None = None, - force_download: bool = False, -) -> str: - checkpoint_path = str(checkpoint_path or "").strip() - if not checkpoint_path: - raise ValueError("MolmoAct2 policy requires `checkpoint_path`.") - local_path = Path(checkpoint_path).expanduser() - if local_path.exists(): - return str(local_path) - return snapshot_download( - repo_id=checkpoint_path, - repo_type="model", - revision=revision, - force_download=force_download, - ignore_patterns=["*.py", "*.pyc", "__pycache__/*"], - token=_hf_token(), - ) - - def _load_hf_norm_stats_for_tag( checkpoint_path: str, *, diff --git a/tests/policies/molmoact2/test_molmoact2.py b/tests/policies/molmoact2/test_molmoact2.py index 3631bcc9b..5fba72913 100644 --- a/tests/policies/molmoact2/test_molmoact2.py +++ b/tests/policies/molmoact2/test_molmoact2.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,16 +33,16 @@ pytest.importorskip("scipy") from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature from lerobot.policies import get_policy_class, make_policy_config from lerobot.policies.molmoact2 import ( - configuration_molmoact2 as molmoact2_config, modeling_molmoact2 as molmoact2_modeling, processor_molmoact2 as molmoact2_processor, ) -from lerobot.policies.molmoact2.configuration_molmoact2 import ( - MolmoAct2Config, - MolmoAct2CosineDecayWithWarmupSchedulerConfig, - infer_molmoact2_max_sequence_length, +from lerobot.policies.molmoact2.configuration_molmoact2 import MolmoAct2Config +from lerobot.policies.molmoact2.modeling_molmoact2 import ( + MolmoAct2Policy, + _apply_action_chunk_padding_mask, + _apply_action_dim_padding_mask, + _combine_rollout_seeds, ) -from lerobot.policies.molmoact2.modeling_molmoact2 import MolmoAct2Policy from lerobot.policies.molmoact2.processor_molmoact2 import ( MolmoAct2ClampNormalizedProcessorStep, MolmoAct2MaskedNormalizerProcessorStep, @@ -53,6 +51,7 @@ from lerobot.policies.molmoact2.processor_molmoact2 import ( _add_gripper_masks_to_stats, _build_discrete_state_string, _normalize_question_text, + infer_molmoact2_max_sequence_length, make_molmoact2_pre_post_processors, ) from lerobot.policies.rtc.configuration_rtc import RTCConfig @@ -71,34 +70,38 @@ def test_molmoact2_policy_registration(): assert cfg.per_episode_seed is False assert cfg.eval_seed is None assert cfg.normalize_language is True - assert cfg.get_scheduler_preset().num_decay_steps is None + assert cfg.get_scheduler_preset().num_decay_steps == 100_000 assert cfg.action_delta_indices == list(range(cfg.chunk_size)) assert get_policy_class("molmoact2") is MolmoAct2Policy def test_molmoact2_checkpoint_download_ignores_remote_python(monkeypatch): + import huggingface_hub + download_kwargs = {} def fake_snapshot_download(**kwargs): download_kwargs.update(kwargs) return "/tmp/downloaded-molmoact2" - monkeypatch.setattr(molmoact2_config, "snapshot_download", fake_snapshot_download) + monkeypatch.setattr(huggingface_hub, "snapshot_download", fake_snapshot_download) - checkpoint_location = molmoact2_config._resolve_checkpoint_location("allenai/MolmoAct2") + checkpoint_location = molmoact2_modeling._resolve_checkpoint_location("allenai/MolmoAct2") assert checkpoint_location == "/tmp/downloaded-molmoact2" assert download_kwargs["ignore_patterns"] == ["*.py", "*.pyc", "__pycache__/*"] -def test_molmoact2_scheduler_decay_steps_auto_match_training_steps(): +def test_molmoact2_scheduler_auto_scales_to_training_steps(): + from lerobot.optim import CosineDecayWithWarmupSchedulerConfig + param = torch.nn.Parameter(torch.ones(())) optimizer = torch.optim.AdamW([param], lr=0.001) - config = MolmoAct2CosineDecayWithWarmupSchedulerConfig( + config = CosineDecayWithWarmupSchedulerConfig( peak_lr=0.01, decay_lr=0.001, num_warmup_steps=10, - num_decay_steps=None, + num_decay_steps=100_000, ) scheduler = config.build(optimizer, num_training_steps=100) @@ -123,9 +126,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task(): batch_size=3, device=torch.device("cpu"), ) - expected_first = torch.Generator().manual_seed( - MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3) - ) + expected_first = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1000, batch_size=3)) assert torch.allclose(torch.rand(4, generator=first), torch.rand(4, generator=expected_first)) policy.reset() @@ -134,9 +135,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task(): batch_size=3, device=torch.device("cpu"), ) - expected_second = torch.Generator().manual_seed( - MolmoAct2Policy._combine_rollout_seeds(first_seed=1003, batch_size=3) - ) + expected_second = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1003, batch_size=3)) assert torch.allclose(torch.rand(4, generator=second), torch.rand(4, generator=expected_second)) policy.reset() @@ -145,9 +144,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task(): batch_size=3, device=torch.device("cpu"), ) - expected_new_task = torch.Generator().manual_seed( - MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3) - ) + expected_new_task = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1000, batch_size=3)) assert torch.allclose(torch.rand(4, generator=new_task), torch.rand(4, generator=expected_new_task)) @@ -537,36 +534,26 @@ def test_train_action_expert_only_requires_continuous_action_mode(): def test_molmoact2_sequence_length_is_inferred_from_fixed_token_budget(): - cfg = MolmoAct2Config( - action_mode="both", - chunk_size=10, - n_action_steps=10, - image_keys=["observation.images.image", "observation.images.wrist_image"], - input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))}, - output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,))}, - ) - - assert cfg.max_sequence_length is None - assert cfg.inferred_max_sequence_length() == 640 - assert cfg.inferred_max_sequence_length(include_discrete_action=False) == 576 assert ( infer_molmoact2_max_sequence_length( - num_images=2, - state_dim=8, - action_dim=7, - action_horizon=30, - include_discrete_action=True, + num_images=2, state_dim=8, action_dim=7, action_horizon=10, include_discrete_action=True + ) + == 640 + ) + assert ( + infer_molmoact2_max_sequence_length( + num_images=2, state_dim=8, action_dim=7, action_horizon=10, include_discrete_action=False + ) + == 576 + ) + assert ( + infer_molmoact2_max_sequence_length( + num_images=2, state_dim=8, action_dim=7, action_horizon=30, include_discrete_action=True ) == 768 ) -def test_molmoact2_sequence_length_override_is_preserved(): - cfg = MolmoAct2Config(max_sequence_length=1024) - - assert cfg.inferred_max_sequence_length(num_images=2, state_dim=8, action_dim=7) == 1024 - - def test_train_action_expert_only_freezes_non_action_expert_params(): class DummyBackbone(torch.nn.Module): def __init__(self): @@ -963,7 +950,7 @@ def test_action_dim_padding_loss_reduces_like_old_trainer(): ] ) - reduced = MolmoAct2Policy._apply_action_dim_padding_mask(loss, action_dim_is_pad) + reduced = _apply_action_dim_padding_mask(loss, action_dim_is_pad) expected = torch.stack( [ @@ -979,7 +966,7 @@ def test_action_chunk_padding_keeps_old_mean_denominator(): loss = torch.ones(1, 2, 4, 3) action_horizon_is_pad = torch.tensor([[False, False, True, True]]) - masked = MolmoAct2Policy._apply_action_chunk_padding_mask(loss, action_horizon_is_pad) + masked = _apply_action_chunk_padding_mask(loss, action_horizon_is_pad) assert masked.mean().item() == 0.5