mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-11 21:49:47 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 17e217d175 |
@@ -17,7 +17,7 @@ the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
Install LeRobot with the MolmoAct2 optional dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[molmoact2]"
|
||||
uv sync --locked --extra molmoact2
|
||||
```
|
||||
|
||||
To run the models in this repository, you need an NVIDIA GPU. The measurements
|
||||
@@ -46,8 +46,8 @@ The repo has been tested with Ubuntu 22.04.
|
||||
|
||||
To use MolmoAct2 in a LeRobot training config, set:
|
||||
|
||||
```python
|
||||
policy.type=molmoact2
|
||||
```bash
|
||||
--policy.type=molmoact2
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
@@ -1 +1 @@
|
||||
../../../../docs/source/policy_molmoact2_README.md
|
||||
../../../../docs/source/molmoact2.mdx
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -16,16 +14,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import (
|
||||
AdamWConfig,
|
||||
@@ -37,146 +28,6 @@ from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
from ..rtc.configuration_rtc import RTCConfig
|
||||
|
||||
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
|
||||
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
|
||||
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
|
||||
MOLMOACT2_TASK_TOKEN_BUDGET = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
|
||||
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
|
||||
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
|
||||
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_checkpoint_location(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
checkpoint_path = str(checkpoint_path or "").strip()
|
||||
if not checkpoint_path:
|
||||
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
|
||||
local_path = Path(checkpoint_path).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=checkpoint_path,
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
def _load_hf_norm_metadata_for_tag(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None,
|
||||
force_download: bool,
|
||||
norm_tag: str | None,
|
||||
) -> dict[str, Any]:
|
||||
norm_tag = str(norm_tag or "").strip()
|
||||
if not norm_tag:
|
||||
return {}
|
||||
checkpoint_location = Path(
|
||||
_resolve_checkpoint_location(
|
||||
checkpoint_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
norm_stats_filename = "norm_stats.json"
|
||||
config_path = checkpoint_location / "config.json"
|
||||
if config_path.exists():
|
||||
with suppress(OSError, json.JSONDecodeError):
|
||||
norm_stats_filename = str(
|
||||
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
|
||||
)
|
||||
stats_path = checkpoint_location / norm_stats_filename
|
||||
if not stats_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
|
||||
)
|
||||
payload = json.loads(stats_path.read_text())
|
||||
metadata_by_tag = payload.get("metadata_by_tag")
|
||||
if not isinstance(metadata_by_tag, dict):
|
||||
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
|
||||
metadata = metadata_by_tag.get(norm_tag)
|
||||
if not isinstance(metadata, dict):
|
||||
available = sorted(str(tag) for tag in metadata_by_tag)
|
||||
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
|
||||
return metadata
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
@dataclass
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
    """Cosine-with-warmup scheduler config whose decay length may be left unset.

    When ``num_decay_steps`` is ``None`` the decay spans the run's training
    steps. That number is first available in ``build()``, so the substitution
    happens there.
    """

    num_decay_steps: int | None

    def build(self, optimizer, num_training_steps: int):
        """Build the scheduler via a ``CosineDecayWithWarmupSchedulerConfig`` with a concrete decay length."""
        decay_steps = self.num_decay_steps
        if decay_steps is None:
            decay_steps = num_training_steps
        resolved = CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.peak_lr,
            decay_lr=self.decay_lr,
            num_warmup_steps=self.num_warmup_steps,
            num_decay_steps=decay_steps,
        )
        return resolved.build(optimizer, num_training_steps=num_training_steps)
|
||||
|
||||
|
||||
def _round_up(value: int, multiple: int) -> int:
|
||||
return int(math.ceil(value / multiple) * multiple)
|
||||
|
||||
|
||||
def infer_molmoact2_max_sequence_length(
    *,
    num_images: int,
    state_dim: int,
    action_dim: int,
    action_horizon: int,
    include_discrete_action: bool,
) -> int:
    """Estimate the padded sequence cap from the module's token-budget constants.

    The total is image tokens + prompt budget (fixed + task + state + margin)
    + optional discrete-action tokens, rounded up to a multiple of
    ``MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE``.
    """
    # Out-of-range inputs are replaced by fallbacks rather than rejected.
    image_count = num_images if num_images >= 1 else MOLMOACT2_DEFAULT_NUM_IMAGES
    state_width = max(state_dim, 0)
    action_width = max(action_dim, 1)
    horizon = max(action_horizon, 1)

    total = image_count * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
    total += (
        MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
        + MOLMOACT2_TASK_TOKEN_BUDGET
        + state_width
        + MOLMOACT2_SEQUENCE_LENGTH_MARGIN
    )
    if include_discrete_action:
        scaled_width = math.ceil(action_width * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM)
        per_step = max(MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP, scaled_width)
        total += MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + horizon * per_step

    return _round_up(total, MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("molmoact2")
|
||||
@dataclass
|
||||
@@ -255,7 +106,7 @@ class MolmoAct2Config(PreTrainedConfig):
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
scheduler_warmup_steps: int = 200
|
||||
scheduler_decay_steps: int | None = None
|
||||
scheduler_decay_steps: int = 100_000
|
||||
scheduler_decay_lr: float = 1e-6
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
@@ -333,41 +184,6 @@ class MolmoAct2Config(PreTrainedConfig):
|
||||
if self.max_sequence_length is not None and self.max_sequence_length < 1:
|
||||
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
|
||||
|
||||
def inferred_max_sequence_length(
    self,
    *,
    num_images: int | None = None,
    state_dim: int | None = None,
    action_dim: int | None = None,
    action_horizon: int | None = None,
    include_discrete_action: bool | None = None,
) -> int:
    """Return ``max_sequence_length`` if set, otherwise infer it.

    Each argument left as ``None`` is filled from this config before
    delegating to ``infer_molmoact2_max_sequence_length``.
    """
    explicit_cap = self.max_sequence_length
    if explicit_cap is not None:
        return int(explicit_cap)

    if num_images is None:
        num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
    if state_dim is None:
        state_feature = self.robot_state_feature
        state_dim = 0 if state_feature is None else int(state_feature.shape[0])
    if action_dim is None:
        action_feature = self.action_feature
        if action_feature is None:
            action_dim = self.expected_max_action_dim
        else:
            action_dim = int(action_feature.shape[0])
    if action_horizon is None:
        action_horizon = self.chunk_size
    if include_discrete_action is None:
        include_discrete_action = self.action_mode in {"discrete", "both"}

    resolved = {
        "num_images": int(num_images),
        "state_dim": int(state_dim),
        "action_dim": int(action_dim),
        "action_horizon": int(action_horizon),
        "include_discrete_action": bool(include_discrete_action),
    }
    return infer_molmoact2_max_sequence_length(**resolved)
|
||||
|
||||
@property
def observation_delta_indices(self) -> None:
    """Always ``None``: this config supplies no observation delta indices."""
    return None
|
||||
@@ -390,7 +206,7 @@ class MolmoAct2Config(PreTrainedConfig):
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
@@ -426,94 +242,3 @@ class MolmoAct2Config(PreTrainedConfig):
|
||||
shape=(self.expected_max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def apply_norm_tag_metadata(self) -> None:
    """Copy fields from the checkpoint's norm-tag metadata onto this config.

    Does nothing when ``norm_tag`` is empty. ``chunk_size`` and
    ``n_action_steps`` are overwritten whenever the metadata provides them;
    ``setup_type`` and ``control_mode`` are only filled in when currently falsy.
    """
    if str(self.norm_tag or "").strip() == "":
        return
    metadata = _load_hf_norm_metadata_for_tag(
        self.checkpoint_path,
        revision=self.checkpoint_revision,
        force_download=bool(self.checkpoint_force_download),
        norm_tag=self.norm_tag,
    )

    horizon = metadata.get("action_horizon")
    if horizon is not None:
        self.chunk_size = int(horizon)
    action_steps = metadata.get("n_action_steps")
    if action_steps is not None:
        self.n_action_steps = int(action_steps)

    # Values already set on the config take precedence over metadata.
    for attr in ("setup_type", "control_mode"):
        if getattr(self, attr):
            continue
        value = metadata.get(attr)
        if value is not None:
            setattr(self, attr, str(value))
|
||||
|
||||
def saved_policy_action_mode(self) -> str | None:
    """Read ``action_mode`` from ``<pretrained_path>/config.json``.

    Returns:
        ``"continuous"``, ``"discrete"`` or ``"both"`` when the file holds one
        of those values; ``None`` when ``pretrained_path`` is unset, the file
        is missing or unreadable, or the stored value is anything else.
    """
    pretrained_path = getattr(self, "pretrained_path", None)
    if pretrained_path is None:
        return None
    config_path = Path(pretrained_path) / "config.json"
    try:
        saved_config = json.loads(config_path.read_text())
    except (OSError, ValueError):
        # Missing/unreadable file (OSError) or undecodable/invalid JSON
        # (ValueError covers JSONDecodeError and UnicodeDecodeError).
        return None
    # A top-level list or scalar has no ``action_mode`` entry.
    if not isinstance(saved_config, dict):
        return None
    mode = saved_config.get("action_mode")
    # Check the type first: an unhashable value (list/dict) would raise
    # TypeError in the set membership test.
    if isinstance(mode, str) and mode in {"continuous", "discrete", "both"}:
        return str(mode)
    return None
|
||||
|
||||
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
    """Return the saved policy's action mode when given, else ``self.action_mode``."""
    if saved_policy_action_mode:
        return saved_policy_action_mode
    return self.action_mode
|
||||
|
||||
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
    """Reject an ``inference_action_mode`` that the training mode cannot serve.

    No check is made when ``inference_action_mode`` is ``None``.

    Raises:
        ValueError: If continuous inference is requested from a discrete-only
            training mode, or discrete inference from a continuous-only one.
    """
    requested = self.inference_action_mode
    if requested is None:
        return
    pairing = (requested, self.training_action_mode(saved_policy_action_mode))
    if pairing == ("continuous", "discrete"):
        raise ValueError(
            "MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
            "continuous inference."
        )
    if pairing == ("discrete", "continuous"):
        raise ValueError(
            "MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
            "discrete inference. Train with action_mode='both' or action_mode='discrete' first."
        )
|
||||
|
||||
def validate_checkpoint_action_mode(
    self,
    checkpoint_action_mode: str,
    *,
    has_action_expert: bool,
) -> None:
    """Check that the checkpoint can serve this config's ``action_mode``.

    Raises:
        ValueError: If ``action_mode`` is ``"both"`` but the checkpoint is not,
            if it is ``"discrete"`` but the checkpoint is neither ``"discrete"``
            nor ``"both"``, or if a continuous-capable mode is requested
            without an action expert.
    """
    configured = self.action_mode
    if configured == "both":
        if checkpoint_action_mode != "both":
            raise ValueError(
                f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
            )
    elif configured == "discrete":
        if checkpoint_action_mode not in {"discrete", "both"}:
            raise ValueError(
                f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
                f"got {checkpoint_action_mode!r}."
            )
    needs_expert = configured in {"continuous", "both"}
    if needs_expert and not has_action_expert:
        raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
|
||||
|
||||
def resolve_inference_action_mode(
    self,
    requested_mode: str | None,
    saved_policy_action_mode: str | None = None,
) -> str:
    """Return the inference mode to use, falling back to ``inference_action_mode``.

    Raises:
        ValueError: If no mode is available, the mode is not ``"continuous"``
            or ``"discrete"``, or it conflicts with the training mode.
    """
    trained = self.training_action_mode(saved_policy_action_mode)
    mode = self.inference_action_mode if requested_mode is None else requested_mode
    if mode is None:
        raise ValueError(
            "MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
            "to either 'continuous' or 'discrete'."
        )
    if mode not in {"continuous", "discrete"}:
        raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
    pairing = (mode, trained)
    if pairing == ("continuous", "discrete"):
        raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
    if pairing == ("discrete", "continuous"):
        raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
    return mode
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,9 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MolmoAct2 policy for LeRobot.
|
||||
|
||||
MolmoAct2 is a VLM-based robotics policy from Allen AI that combines a
|
||||
Molmo vision-language backbone with a per-layer flow-matching action expert
|
||||
for continuous action generation, plus an optional discrete action token
|
||||
head. This module wraps the vendored HF model implementation
|
||||
(``molmoact2_hf_model/``) into the LeRobot ``PreTrainedPolicy`` interface.
|
||||
|
||||
Paper: https://allenai.org/blog/molmoact2
|
||||
Code: https://github.com/allenai/molmoact2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import types
|
||||
from collections import deque
|
||||
@@ -35,13 +46,58 @@ from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
|
||||
|
||||
from ..rtc.modeling_rtc import RTCProcessor
|
||||
from .configuration_molmoact2 import MolmoAct2Config, _hf_token, _resolve_checkpoint_location
|
||||
from .configuration_molmoact2 import MolmoAct2Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_checkpoint_location(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
"""Resolve a checkpoint path to a local directory, downloading from Hub if needed."""
|
||||
checkpoint_path = str(checkpoint_path or "").strip()
|
||||
if not checkpoint_path:
|
||||
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
|
||||
from pathlib import Path
|
||||
|
||||
local_path = Path(checkpoint_path).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
return snapshot_download(
|
||||
repo_id=checkpoint_path,
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
def _torch_dtype(dtype: str) -> torch.dtype:
|
||||
"""Convert a dtype name string to a torch.dtype."""
|
||||
if dtype == "float32":
|
||||
return torch.float32
|
||||
if dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if dtype == "float16":
|
||||
return torch.float16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from .hf_model.configuration_molmoact2 import MolmoAct2Config as HFMolmoAct2Config
|
||||
from .hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
|
||||
from .molmoact2_hf_model.configuration_molmoact2 import MolmoAct2Config as HFMolmoAct2Config
|
||||
from .molmoact2_hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
|
||||
else:
|
||||
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
|
||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||
@@ -49,7 +105,7 @@ else:
|
||||
MolmoAct2ForConditionalGeneration = None
|
||||
|
||||
if TYPE_CHECKING or (_transformers_available and _scipy_available):
|
||||
from .hf_model.action_tokenizer import UniversalActionProcessor
|
||||
from .molmoact2_hf_model.action_tokenizer import UniversalActionProcessor
|
||||
else:
|
||||
UniversalActionProcessor = None
|
||||
|
||||
@@ -70,6 +126,156 @@ _MODEL_INPUT_KEYS = {
|
||||
}
|
||||
|
||||
|
||||
def _load_hf_norm_metadata_for_tag(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None,
|
||||
force_download: bool,
|
||||
norm_tag: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Read per-tag metadata from the checkpoint's ``norm_stats.json``."""
|
||||
norm_tag = str(norm_tag or "").strip()
|
||||
if not norm_tag:
|
||||
return {}
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
|
||||
checkpoint_location = Path(
|
||||
_resolve_checkpoint_location(
|
||||
checkpoint_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
norm_stats_filename = "norm_stats.json"
|
||||
config_path = checkpoint_location / "config.json"
|
||||
if config_path.exists():
|
||||
with suppress(OSError, json.JSONDecodeError):
|
||||
norm_stats_filename = str(
|
||||
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
|
||||
)
|
||||
stats_path = checkpoint_location / norm_stats_filename
|
||||
if not stats_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
|
||||
)
|
||||
payload = json.loads(stats_path.read_text())
|
||||
metadata_by_tag = payload.get("metadata_by_tag")
|
||||
if not isinstance(metadata_by_tag, dict):
|
||||
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
|
||||
metadata = metadata_by_tag.get(norm_tag)
|
||||
if not isinstance(metadata, dict):
|
||||
available = sorted(str(tag) for tag in metadata_by_tag)
|
||||
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
|
||||
return metadata
|
||||
|
||||
|
||||
def _apply_norm_tag_metadata(config: MolmoAct2Config) -> None:
|
||||
"""Populate config fields from the checkpoint's norm-tag metadata."""
|
||||
if not str(config.norm_tag or "").strip():
|
||||
return
|
||||
metadata = _load_hf_norm_metadata_for_tag(
|
||||
config.checkpoint_path,
|
||||
revision=config.checkpoint_revision,
|
||||
force_download=bool(config.checkpoint_force_download),
|
||||
norm_tag=config.norm_tag,
|
||||
)
|
||||
if metadata.get("action_horizon") is not None:
|
||||
config.chunk_size = int(metadata["action_horizon"])
|
||||
if metadata.get("n_action_steps") is not None:
|
||||
config.n_action_steps = int(metadata["n_action_steps"])
|
||||
if not config.setup_type and metadata.get("setup_type") is not None:
|
||||
config.setup_type = str(metadata["setup_type"])
|
||||
if not config.control_mode and metadata.get("control_mode") is not None:
|
||||
config.control_mode = str(metadata["control_mode"])
|
||||
|
||||
|
||||
def _saved_policy_action_mode(config: MolmoAct2Config) -> str | None:
|
||||
"""Read the action mode from a LeRobot-saved checkpoint's ``config.json``."""
|
||||
from pathlib import Path
|
||||
|
||||
pretrained_path = getattr(config, "pretrained_path", None)
|
||||
if pretrained_path is None:
|
||||
return None
|
||||
config_path = Path(pretrained_path) / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
try:
|
||||
mode = json.loads(config_path.read_text()).get("action_mode")
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
if mode in {"continuous", "discrete", "both"}:
|
||||
return str(mode)
|
||||
return None
|
||||
|
||||
|
||||
def _training_action_mode(config: MolmoAct2Config, saved_policy_action_mode: str | None = None) -> str:
|
||||
return saved_policy_action_mode or config.action_mode
|
||||
|
||||
|
||||
def _validate_inference_action_mode(
|
||||
config: MolmoAct2Config, saved_policy_action_mode: str | None = None
|
||||
) -> None:
|
||||
"""Check that the requested inference mode is compatible with the training mode."""
|
||||
requested_mode = config.inference_action_mode
|
||||
if requested_mode is None:
|
||||
return
|
||||
training_mode = _training_action_mode(config, saved_policy_action_mode)
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
|
||||
"continuous inference."
|
||||
)
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
|
||||
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
|
||||
)
|
||||
|
||||
|
||||
def _validate_checkpoint_action_mode(
|
||||
config: MolmoAct2Config,
|
||||
checkpoint_action_mode: str,
|
||||
*,
|
||||
has_action_expert: bool,
|
||||
) -> None:
|
||||
"""Check that the checkpoint's action mode is compatible with the config."""
|
||||
if config.action_mode == "both" and checkpoint_action_mode != "both":
|
||||
raise ValueError(
|
||||
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if config.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
|
||||
raise ValueError(
|
||||
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
|
||||
f"got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if config.action_mode in {"continuous", "both"} and not has_action_expert:
|
||||
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
|
||||
|
||||
|
||||
def _resolve_inference_action_mode(
    config: MolmoAct2Config,
    requested_mode: str | None,
    saved_policy_action_mode: str | None = None,
) -> str:
    """Return the inference mode to use, falling back to ``config.inference_action_mode``.

    Raises:
        ValueError: If no mode is available, the mode is not ``"continuous"``
            or ``"discrete"``, or it conflicts with the training mode.
    """
    trained = _training_action_mode(config, saved_policy_action_mode)
    mode = config.inference_action_mode if requested_mode is None else requested_mode
    if mode is None:
        raise ValueError(
            "MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
            "to either 'continuous' or 'discrete'."
        )
    if mode not in {"continuous", "discrete"}:
        raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
    pairing = (mode, trained)
    if pairing == ("continuous", "discrete"):
        raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
    if pairing == ("discrete", "continuous"):
        raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
    return mode
|
||||
|
||||
|
||||
def _strict_load_safetensors_weights(model: torch.nn.Module, checkpoint_location: str) -> None:
|
||||
index_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_INDEX_NAME)
|
||||
single_file_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_NAME)
|
||||
@@ -103,16 +309,6 @@ def _strict_load_safetensors_weights(model: torch.nn.Module, checkpoint_location
|
||||
)
|
||||
|
||||
|
||||
def _torch_dtype(dtype: str) -> torch.dtype:
|
||||
if dtype == "float32":
|
||||
return torch.float32
|
||||
if dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if dtype == "float16":
|
||||
return torch.float16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def _sample_beta_timesteps(
|
||||
*,
|
||||
batch_size: int,
|
||||
@@ -136,7 +332,180 @@ def _sample_beta_timesteps(
|
||||
return time_offset + scale * samples
|
||||
|
||||
|
||||
def _mask_discrete_action_spans(
|
||||
*,
|
||||
input_ids: Tensor,
|
||||
mask: Tensor,
|
||||
start_token_id: int | None,
|
||||
end_token_id: int | None,
|
||||
) -> Tensor:
|
||||
if start_token_id is None or end_token_id is None:
|
||||
return mask
|
||||
mask = mask.clone()
|
||||
for batch_idx in range(input_ids.shape[0]):
|
||||
row = input_ids[batch_idx]
|
||||
starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist()
|
||||
ends = (row == int(end_token_id)).nonzero(as_tuple=False).flatten().tolist()
|
||||
end_ptr = 0
|
||||
for start in starts:
|
||||
while end_ptr < len(ends) and ends[end_ptr] < start:
|
||||
end_ptr += 1
|
||||
if end_ptr >= len(ends):
|
||||
mask[batch_idx, start:] = False
|
||||
break
|
||||
end = int(ends[end_ptr])
|
||||
mask[batch_idx, start : end + 1] = False
|
||||
end_ptr += 1
|
||||
return mask
|
||||
|
||||
|
||||
def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
attention_mask = model_inputs.get("attention_mask")
|
||||
if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()):
|
||||
model_inputs = dict(model_inputs)
|
||||
model_inputs.pop("attention_mask", None)
|
||||
return model_inputs
|
||||
|
||||
|
||||
def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None:
|
||||
if mask is None:
|
||||
return None
|
||||
return (
|
||||
mask.unsqueeze(1)
|
||||
.expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1)))
|
||||
.reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:])
|
||||
)
|
||||
|
||||
|
||||
def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None:
|
||||
if action_dim_is_pad is None:
|
||||
return None
|
||||
mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool)
|
||||
if mask.ndim == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[-1] != target.shape[-1]:
|
||||
raise ValueError(
|
||||
f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}."
|
||||
)
|
||||
if mask.shape[0] == 1 and target.shape[0] != 1:
|
||||
mask = mask.expand(target.shape[0], -1)
|
||||
if mask.shape[0] != target.shape[0]:
|
||||
raise ValueError(
|
||||
f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}."
|
||||
)
|
||||
while mask.ndim < target.ndim:
|
||||
mask = mask.unsqueeze(1)
|
||||
return mask
|
||||
|
||||
|
||||
def _mask_action_dim_tensor(tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
|
||||
if action_dim_is_pad is None:
|
||||
return tensor
|
||||
valid_mask = _action_dim_valid_mask(tensor, action_dim_is_pad)
|
||||
if valid_mask is None:
|
||||
return tensor
|
||||
return tensor.masked_fill(~valid_mask, 0)
|
||||
|
||||
|
||||
def _apply_action_dim_padding_mask(loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
    """Average ``loss`` over its last dim, counting only non-padded action dimensions.

    ``loss`` is returned unchanged (last dim not reduced) when no valid mask
    is available. The divisor is clamped to at least 1.
    """
    valid_mask = _action_dim_valid_mask(loss, action_dim_is_pad)
    if valid_mask is None:
        return loss
    weights = valid_mask.to(dtype=loss.dtype)
    valid_count = weights.sum(dim=-1).clamp_min(1.0)
    masked_total = (loss * weights).sum(dim=-1)
    return masked_total / valid_count
|
||||
|
||||
|
||||
def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor:
|
||||
if action_horizon_is_pad is None:
|
||||
return loss
|
||||
valid_action = (
|
||||
(~action_horizon_is_pad.to(device=loss.device, dtype=torch.bool)).unsqueeze(1).unsqueeze(-1)
|
||||
)
|
||||
return loss * valid_action
|
||||
|
||||
|
||||
def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int:
|
||||
seed = 0
|
||||
for idx in range(batch_size):
|
||||
seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1)
|
||||
return seed
|
||||
|
||||
|
||||
def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None:
|
||||
task = batch.get("task")
|
||||
if task is None:
|
||||
task = batch.get("observation.language")
|
||||
if task is None:
|
||||
return None
|
||||
if isinstance(task, str):
|
||||
return (task,)
|
||||
if isinstance(task, (list, tuple)):
|
||||
return tuple(str(item) for item in task)
|
||||
return (str(task),)
|
||||
|
||||
|
||||
def _extract_discrete_token_bins(
|
||||
generated_ids: list[int],
|
||||
start_token_id: int,
|
||||
end_token_id: int,
|
||||
token_id_to_bin: dict[int, int],
|
||||
) -> list[int]:
|
||||
start_idx = None
|
||||
end_idx = None
|
||||
for idx, token_id in enumerate(generated_ids):
|
||||
if token_id == start_token_id:
|
||||
start_idx = idx
|
||||
break
|
||||
if start_idx is not None:
|
||||
for idx in range(start_idx + 1, len(generated_ids)):
|
||||
if generated_ids[idx] == end_token_id:
|
||||
end_idx = idx
|
||||
break
|
||||
span_start = 0 if start_idx is None else start_idx + 1
|
||||
span_end = len(generated_ids) if end_idx is None else end_idx
|
||||
return [
|
||||
int(token_id_to_bin[token_id])
|
||||
for token_id in generated_ids[span_start:span_end]
|
||||
if token_id in token_id_to_bin
|
||||
]
|
||||
|
||||
|
||||
def _weighted_mean(values: Tensor, weights: Tensor | None) -> Tensor:
|
||||
if weights is None:
|
||||
return values.mean()
|
||||
weights = weights.to(device=values.device, dtype=values.dtype)
|
||||
return torch.dot(values, weights) / weights.sum().clamp_min(1.0)
|
||||
|
||||
|
||||
def _weighted_per_example(
|
||||
values: Tensor,
|
||||
weights: Tensor | None,
|
||||
example_indices: Tensor,
|
||||
batch_size: int,
|
||||
) -> Tensor:
|
||||
values = values.float()
|
||||
if weights is None:
|
||||
weights = torch.ones_like(values)
|
||||
else:
|
||||
weights = weights.to(device=values.device, dtype=values.dtype)
|
||||
loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
|
||||
weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
|
||||
loss_sum.scatter_add_(0, example_indices, values * weights)
|
||||
weight_sum.scatter_add_(0, example_indices, weights)
|
||||
global_weight_sum = weight_sum.sum().clamp_min(1.0)
|
||||
return loss_sum * float(batch_size) / global_weight_sum
|
||||
|
||||
|
||||
class MolmoAct2Policy(PreTrainedPolicy):
|
||||
"""MolmoAct2 policy wrapping the vendored HF model for LeRobot.
|
||||
|
||||
Supports three training modes via ``config.action_mode``:
|
||||
``"continuous"`` (flow-matching only), ``"discrete"`` (autoregressive
|
||||
token prediction only), or ``"both"`` (joint loss). At inference,
|
||||
``config.inference_action_mode`` selects which head generates actions.
|
||||
"""
|
||||
|
||||
config_class = MolmoAct2Config
|
||||
name = "molmoact2"
|
||||
|
||||
@@ -149,10 +518,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.config.apply_norm_tag_metadata()
|
||||
_apply_norm_tag_metadata(self.config)
|
||||
self.config.validate_features()
|
||||
del inputs, kwargs, dataset_stats, dataset_meta
|
||||
self._checkpoint_action_mode = self.config.saved_policy_action_mode()
|
||||
self._checkpoint_action_mode = _saved_policy_action_mode(self.config)
|
||||
self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps)
|
||||
self._rollout_action_generator: torch.Generator | None = None
|
||||
self._rollout_task_key: tuple[Any, ...] | None = None
|
||||
@@ -160,7 +529,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
self.rtc_processor: RTCProcessor | None = None
|
||||
self.action_tokenizer: Any | None = None
|
||||
self._load_hf_model()
|
||||
self.config.validate_inference_action_mode(self._checkpoint_action_mode)
|
||||
_validate_inference_action_mode(self.config, self._checkpoint_action_mode)
|
||||
if self.config.enable_lora_vlm:
|
||||
self._apply_lora_adapters()
|
||||
self.init_rtc_processor()
|
||||
@@ -212,7 +581,8 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
"`policy.checkpoint_force_download=true` after the updated files are pushed."
|
||||
)
|
||||
checkpoint_action_mode = str(self.model.config.action_mode)
|
||||
self.config.validate_checkpoint_action_mode(
|
||||
_validate_checkpoint_action_mode(
|
||||
self.config,
|
||||
checkpoint_action_mode,
|
||||
has_action_expert=bool(getattr(self.model.config, "add_action_expert", False)),
|
||||
)
|
||||
@@ -226,6 +596,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
self.train(self.training)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear the action queue and rollout generator between episodes."""
|
||||
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||
self._rollout_action_generator = None
|
||||
|
||||
@@ -334,6 +705,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
param.requires_grad = False
|
||||
|
||||
def get_optim_params(self) -> list[dict[str, Any]]:
|
||||
"""Return optimizer param groups with per-component learning rates."""
|
||||
vit_params: list[Tensor] = []
|
||||
connector_params: list[Tensor] = []
|
||||
action_expert_params: list[Tensor] = []
|
||||
@@ -419,33 +791,6 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
return int(value)
|
||||
raise RuntimeError("MolmoAct2 could not resolve an action generation horizon.")
|
||||
|
||||
@staticmethod
|
||||
def _mask_discrete_action_spans(
|
||||
*,
|
||||
input_ids: Tensor,
|
||||
mask: Tensor,
|
||||
start_token_id: int | None,
|
||||
end_token_id: int | None,
|
||||
) -> Tensor:
|
||||
if start_token_id is None or end_token_id is None:
|
||||
return mask
|
||||
mask = mask.clone()
|
||||
for batch_idx in range(input_ids.shape[0]):
|
||||
row = input_ids[batch_idx]
|
||||
starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist()
|
||||
ends = (row == int(end_token_id)).nonzero(as_tuple=False).flatten().tolist()
|
||||
end_ptr = 0
|
||||
for start in starts:
|
||||
while end_ptr < len(ends) and ends[end_ptr] < start:
|
||||
end_ptr += 1
|
||||
if end_ptr >= len(ends):
|
||||
mask[batch_idx, start:] = False
|
||||
break
|
||||
end = int(ends[end_ptr])
|
||||
mask[batch_idx, start : end + 1] = False
|
||||
end_ptr += 1
|
||||
return mask
|
||||
|
||||
def _encoder_attention_mask_for_action_expert(
|
||||
self,
|
||||
*,
|
||||
@@ -470,21 +815,13 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
eos_token_id = getattr(self.model.config, "eos_token_id", None)
|
||||
if eos_token_id is not None:
|
||||
mask &= input_ids != int(eos_token_id)
|
||||
return self._mask_discrete_action_spans(
|
||||
return _mask_discrete_action_spans(
|
||||
input_ids=input_ids,
|
||||
mask=mask,
|
||||
start_token_id=getattr(self.model.config, "action_start_token_id", None),
|
||||
end_token_id=getattr(self.model.config, "action_end_token_id", None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
attention_mask = model_inputs.get("attention_mask")
|
||||
if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()):
|
||||
model_inputs = dict(model_inputs)
|
||||
model_inputs.pop("attention_mask", None)
|
||||
return model_inputs
|
||||
|
||||
def _load_discrete_action_tokenizer(self) -> Any:
|
||||
if self.action_tokenizer is None:
|
||||
require_package("transformers", extra="molmoact2")
|
||||
@@ -498,27 +835,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
return self.action_tokenizer
|
||||
|
||||
def _resolve_inference_action_mode(self, requested_mode: str | None) -> str:
|
||||
return self.config.resolve_inference_action_mode(requested_mode, self._checkpoint_action_mode)
|
||||
|
||||
@staticmethod
|
||||
def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int:
|
||||
seed = 0
|
||||
for idx in range(batch_size):
|
||||
seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1)
|
||||
return seed
|
||||
|
||||
@staticmethod
|
||||
def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None:
|
||||
task = batch.get("task")
|
||||
if task is None:
|
||||
task = batch.get("observation.language")
|
||||
if task is None:
|
||||
return None
|
||||
if isinstance(task, str):
|
||||
return (task,)
|
||||
if isinstance(task, (list, tuple)):
|
||||
return tuple(str(item) for item in task)
|
||||
return (str(task),)
|
||||
return _resolve_inference_action_mode(self.config, requested_mode, self._checkpoint_action_mode)
|
||||
|
||||
def _rollout_generator_for_inputs(
|
||||
self,
|
||||
@@ -532,7 +849,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
if self._rollout_action_generator is not None:
|
||||
return self._rollout_action_generator
|
||||
|
||||
task_signature = self._rollout_task_signature(batch)
|
||||
task_signature = _rollout_task_signature(batch)
|
||||
if task_signature != self._rollout_task_key:
|
||||
self._rollout_task_key = task_signature
|
||||
self._rollout_index_for_task = 0
|
||||
@@ -545,72 +862,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
device if device.type == "cuda" and torch.cuda.is_available() else torch.device("cpu")
|
||||
)
|
||||
generator = torch.Generator(device=generator_device)
|
||||
generator.manual_seed(self._combine_rollout_seeds(first_seed, batch_size))
|
||||
generator.manual_seed(_combine_rollout_seeds(first_seed, batch_size))
|
||||
self._rollout_action_generator = generator
|
||||
return generator
|
||||
|
||||
@staticmethod
|
||||
def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None:
|
||||
if mask is None:
|
||||
return None
|
||||
return (
|
||||
mask.unsqueeze(1)
|
||||
.expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1)))
|
||||
.reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:])
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None:
|
||||
if action_dim_is_pad is None:
|
||||
return None
|
||||
mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool)
|
||||
if mask.ndim == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[-1] != target.shape[-1]:
|
||||
raise ValueError(
|
||||
f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}."
|
||||
)
|
||||
if mask.shape[0] == 1 and target.shape[0] != 1:
|
||||
mask = mask.expand(target.shape[0], -1)
|
||||
if mask.shape[0] != target.shape[0]:
|
||||
raise ValueError(
|
||||
f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}."
|
||||
)
|
||||
while mask.ndim < target.ndim:
|
||||
mask = mask.unsqueeze(1)
|
||||
return mask
|
||||
|
||||
@classmethod
|
||||
def _mask_action_dim_tensor(cls, tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
|
||||
if not cls._mask_enabled_static(action_dim_is_pad):
|
||||
return tensor
|
||||
valid_mask = cls._action_dim_valid_mask(tensor, action_dim_is_pad)
|
||||
if valid_mask is None:
|
||||
return tensor
|
||||
return tensor.masked_fill(~valid_mask, 0)
|
||||
|
||||
@staticmethod
|
||||
def _mask_enabled_static(action_dim_is_pad: Tensor | None) -> bool:
|
||||
return action_dim_is_pad is not None
|
||||
|
||||
@classmethod
|
||||
def _apply_action_dim_padding_mask(cls, loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
|
||||
valid_mask = cls._action_dim_valid_mask(loss, action_dim_is_pad)
|
||||
if valid_mask is None:
|
||||
return loss
|
||||
valid = valid_mask.to(dtype=loss.dtype)
|
||||
denom = valid.sum(dim=-1).clamp_min(1.0)
|
||||
return (loss * valid).sum(dim=-1) / denom
|
||||
|
||||
@staticmethod
|
||||
def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor:
|
||||
if action_horizon_is_pad is None:
|
||||
return loss
|
||||
valid_action = (
|
||||
(~action_horizon_is_pad.to(device=loss.device, dtype=torch.bool)).unsqueeze(1).unsqueeze(-1)
|
||||
)
|
||||
return loss * valid_action
|
||||
|
||||
def _prepare_flow_matching_tensors(
|
||||
self,
|
||||
*,
|
||||
@@ -649,7 +904,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
if self.config.mask_action_dim_padding:
|
||||
actions = self._mask_action_dim_tensor(actions, action_dim_is_pad)
|
||||
actions = _mask_action_dim_tensor(actions, action_dim_is_pad)
|
||||
|
||||
expected_noise_shape = (batch_size, num_flow_timesteps, actions.shape[1], actions.shape[2])
|
||||
if noise is None:
|
||||
@@ -661,7 +916,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
f"flow noise must have shape {expected_noise_shape}, got {tuple(noise.shape)}."
|
||||
)
|
||||
if self.config.mask_action_dim_padding:
|
||||
noise = self._mask_action_dim_tensor(noise, action_dim_is_pad)
|
||||
noise = _mask_action_dim_tensor(noise, action_dim_is_pad)
|
||||
|
||||
t_broadcast = timesteps.view(batch_size, num_flow_timesteps, 1, 1)
|
||||
actions_expanded = actions.unsqueeze(1).expand(-1, num_flow_timesteps, -1, -1)
|
||||
@@ -789,7 +1044,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
valid_action = None
|
||||
if action_attention_mask is not None:
|
||||
valid_action = action_attention_mask.to(device=device, dtype=actions.dtype).unsqueeze(-1)
|
||||
valid_action = self._expand_mask(valid_action, num_flow_timesteps)
|
||||
valid_action = _expand_mask(valid_action, num_flow_timesteps)
|
||||
|
||||
rope_cache = None
|
||||
if len(action_expert.blocks) > 0 and action_expert.blocks[0].self_attn.rope is not None:
|
||||
@@ -804,14 +1059,14 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
batch_size,
|
||||
actions.dtype,
|
||||
)
|
||||
cross_mask = self._expand_mask(cross_mask, num_flow_timesteps)
|
||||
cross_mask = _expand_mask(cross_mask, num_flow_timesteps)
|
||||
self_mask = action_expert._build_self_attention_mask(
|
||||
action_attention_mask,
|
||||
actions.shape[1],
|
||||
device,
|
||||
actions.dtype,
|
||||
)
|
||||
self_mask = self._expand_mask(self_mask, num_flow_timesteps)
|
||||
self_mask = _expand_mask(self_mask, num_flow_timesteps)
|
||||
|
||||
conditioning = self._action_time_conditioning(action_expert, timesteps_flat)
|
||||
action_hidden = action_expert.action_embed(xt_flat)
|
||||
@@ -871,8 +1126,8 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
if k_norm is not None:
|
||||
k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2)
|
||||
if num_flow_timesteps != 1:
|
||||
k_ctx = self._expand_mask(k_ctx, num_flow_timesteps)
|
||||
v_ctx = self._expand_mask(v_ctx, num_flow_timesteps)
|
||||
k_ctx = _expand_mask(k_ctx, num_flow_timesteps)
|
||||
v_ctx = _expand_mask(v_ctx, num_flow_timesteps)
|
||||
|
||||
next_action_hidden = action_block(
|
||||
layer_action_hidden,
|
||||
@@ -912,9 +1167,9 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
loss = F.mse_loss(pred_velocity, target_velocity, reduction="none")
|
||||
loss = self._apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad"))
|
||||
loss = _apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad"))
|
||||
if self.config.mask_action_dim_padding:
|
||||
loss = self._apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad"))
|
||||
loss = _apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad"))
|
||||
loss = loss.reshape(batch_size, -1).mean(dim=1)
|
||||
if reduction == "mean":
|
||||
loss = loss.mean()
|
||||
@@ -933,32 +1188,6 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
example_weights[nonempty] = 2.0 / torch.sqrt(token_counts[nonempty])
|
||||
return example_weights[:, None].expand_as(valid_positions)[valid_positions].to(dtype=torch.float32)
|
||||
|
||||
@staticmethod
|
||||
def _weighted_mean(values: Tensor, weights: Tensor | None) -> Tensor:
|
||||
if weights is None:
|
||||
return values.mean()
|
||||
weights = weights.to(device=values.device, dtype=values.dtype)
|
||||
return torch.dot(values, weights) / weights.sum().clamp_min(1.0)
|
||||
|
||||
@staticmethod
|
||||
def _weighted_per_example(
|
||||
values: Tensor,
|
||||
weights: Tensor | None,
|
||||
example_indices: Tensor,
|
||||
batch_size: int,
|
||||
) -> Tensor:
|
||||
values = values.float()
|
||||
if weights is None:
|
||||
weights = torch.ones_like(values)
|
||||
else:
|
||||
weights = weights.to(device=values.device, dtype=values.dtype)
|
||||
loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
|
||||
weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
|
||||
loss_sum.scatter_add_(0, example_indices, values * weights)
|
||||
weight_sum.scatter_add_(0, example_indices, weights)
|
||||
global_weight_sum = weight_sum.sum().clamp_min(1.0)
|
||||
return loss_sum * float(batch_size) / global_weight_sum
|
||||
|
||||
def _discrete_loss_from_backbone_outputs(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
@@ -992,56 +1221,28 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
token_weights = self._discrete_token_weights(valid_positions)
|
||||
if reduction == "none":
|
||||
example_indices = valid_positions.nonzero(as_tuple=False)[:, 0].to(device=hidden_states.device)
|
||||
ce_loss = self._weighted_per_example(
|
||||
ce_loss = _weighted_per_example(
|
||||
token_ce_loss,
|
||||
token_weights,
|
||||
example_indices,
|
||||
int(labels.shape[0]),
|
||||
)
|
||||
else:
|
||||
ce_loss = self._weighted_mean(token_ce_loss, token_weights)
|
||||
ce_loss = _weighted_mean(token_ce_loss, token_weights)
|
||||
if not self.config.softmax_auxiliary_loss:
|
||||
return ce_loss, None
|
||||
|
||||
if reduction == "none":
|
||||
z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_per_example(
|
||||
z_loss = self.config.softmax_auxiliary_loss_scale * _weighted_per_example(
|
||||
log_z.pow(2),
|
||||
token_weights,
|
||||
example_indices,
|
||||
int(labels.shape[0]),
|
||||
)
|
||||
else:
|
||||
z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_mean(
|
||||
log_z.pow(2), token_weights
|
||||
)
|
||||
z_loss = self.config.softmax_auxiliary_loss_scale * _weighted_mean(log_z.pow(2), token_weights)
|
||||
return ce_loss, z_loss
|
||||
|
||||
@staticmethod
|
||||
def _extract_discrete_token_bins(
|
||||
generated_ids: list[int],
|
||||
start_token_id: int,
|
||||
end_token_id: int,
|
||||
token_id_to_bin: dict[int, int],
|
||||
) -> list[int]:
|
||||
start_idx = None
|
||||
end_idx = None
|
||||
for idx, token_id in enumerate(generated_ids):
|
||||
if token_id == start_token_id:
|
||||
start_idx = idx
|
||||
break
|
||||
if start_idx is not None:
|
||||
for idx in range(start_idx + 1, len(generated_ids)):
|
||||
if generated_ids[idx] == end_token_id:
|
||||
end_idx = idx
|
||||
break
|
||||
span_start = 0 if start_idx is None else start_idx + 1
|
||||
span_end = len(generated_ids) if end_idx is None else end_idx
|
||||
return [
|
||||
int(token_id_to_bin[token_id])
|
||||
for token_id in generated_ids[span_start:span_end]
|
||||
if token_id in token_id_to_bin
|
||||
]
|
||||
|
||||
def _action_token_id_to_bin(self) -> dict[int, int]:
|
||||
method = getattr(self.model, "_action_token_id_to_bin", None)
|
||||
if callable(method):
|
||||
@@ -1179,7 +1380,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
chunks: list[Tensor] = []
|
||||
for token_row in generated_token_ids:
|
||||
generated_ids = [int(token_id) for token_id in token_row.detach().cpu().tolist()]
|
||||
discrete_token_ids = self._extract_discrete_token_bins(
|
||||
discrete_token_ids = _extract_discrete_token_bins(
|
||||
generated_ids,
|
||||
int(self.model.config.action_start_token_id),
|
||||
int(self.model.config.action_end_token_id),
|
||||
@@ -1218,7 +1419,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
model_inputs: dict[str, Tensor],
|
||||
action_dim: int,
|
||||
) -> Tensor:
|
||||
model_inputs = self._drop_trivial_attention_mask(model_inputs)
|
||||
model_inputs = _drop_trivial_attention_mask(model_inputs)
|
||||
max_steps = self._discrete_generation_max_steps()
|
||||
static_cache, attention_bias = self._make_discrete_ar_graph_decode_inputs(
|
||||
model_inputs,
|
||||
@@ -1294,7 +1495,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
generator=generator,
|
||||
)
|
||||
if self.config.mask_action_dim_padding:
|
||||
trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad)
|
||||
trajectory = _mask_action_dim_tensor(trajectory, action_dim_is_pad)
|
||||
|
||||
action_context = action_expert.prepare_context(
|
||||
encoder_kv_states=encoder_kv_states,
|
||||
@@ -1327,7 +1528,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
modulation=step_modulation,
|
||||
)
|
||||
if mask_enabled:
|
||||
velocity = self._mask_action_dim_tensor(velocity, action_dim_is_pad)
|
||||
velocity = _mask_action_dim_tensor(velocity, action_dim_is_pad)
|
||||
return velocity
|
||||
|
||||
if self._rtc_enabled():
|
||||
@@ -1352,7 +1553,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
|
||||
trajectory = trajectory + dt * velocity
|
||||
if mask_enabled:
|
||||
trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad)
|
||||
trajectory = _mask_action_dim_tensor(trajectory, action_dim_is_pad)
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=float(flow_timestep[0].item()), x_t=trajectory, v_t=velocity)
|
||||
|
||||
@@ -1363,6 +1564,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
batch: dict[str, Tensor],
|
||||
reduction: str = "mean",
|
||||
) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Compute training loss (flow-matching and/or discrete token loss)."""
|
||||
if reduction not in {"mean", "none"}:
|
||||
raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.")
|
||||
model_inputs = self._model_inputs(batch)
|
||||
@@ -1422,6 +1624,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
"""Generate an action chunk via continuous flow matching or discrete AR decoding."""
|
||||
if "action_mode" in kwargs:
|
||||
raise TypeError(
|
||||
"MolmoAct2 predict_action_chunk got unexpected keyword argument 'action_mode'; "
|
||||
@@ -1476,6 +1679,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
"""Pop one action step from the queue, regenerating the chunk when empty."""
|
||||
if self._rtc_enabled():
|
||||
raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk")
|
||||
self.eval()
|
||||
|
||||
-4
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,5 +11,3 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
+6
-11
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,23 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
from ..modeling_molmoact2 import _hf_token
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_tokenizer_location(
|
||||
@@ -42,6 +36,8 @@ def _resolve_tokenizer_location(
|
||||
local_path = Path(str(tokenizer_path)).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
return snapshot_download(
|
||||
repo_id=str(tokenizer_path),
|
||||
repo_type="model",
|
||||
@@ -134,9 +130,8 @@ class UniversalActionProcessor(ProcessorMixin):
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
except Exception:
|
||||
logger.warning("Error decoding tokens: %s", token, exc_info=True)
|
||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
|
||||
return np.stack(decoded_actions)
|
||||
+1
-4
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,13 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""
|
||||
MolmoAct2 configuration
|
||||
"""
|
||||
|
||||
from typing import Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
+13
-17
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,33 +12,28 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Image processor class for MolmoAct2"""
|
||||
|
||||
from typing import Optional, Union
|
||||
import numpy as np
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
||||
from transformers.image_transforms import convert_to_rgb
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
make_flat_list_of_images,
|
||||
valid_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from transformers.image_transforms import convert_to_rgb
|
||||
from transformers.processing_utils import ImagesKwargs
|
||||
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
||||
from transformers.utils import logging
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.utils import TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,8 +66,8 @@ def resize_image(
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
||||
else:
|
||||
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
|
||||
image.dtype
|
||||
assert image.dtype == torch.uint8, (
|
||||
f"SigLIP expects float images or uint8 images, but got {image.dtype}"
|
||||
)
|
||||
in_min = 0.0
|
||||
in_max = 255.0
|
||||
@@ -96,7 +89,6 @@ def resize_image(
|
||||
def select_tiling(h, w, patch_size, max_num_crops):
|
||||
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
||||
original_size = np.stack([h, w]) # [1, 2]
|
||||
original_res = h * w
|
||||
tilings = []
|
||||
for i in range(1, max_num_crops + 1):
|
||||
for j in range(1, max_num_crops + 1):
|
||||
@@ -406,13 +398,17 @@ class MolmoAct2ImageProcessor(BaseImageProcessor):
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool = True,
|
||||
max_crops: int = 8,
|
||||
overlap_margins: list[int] = [4, 4],
|
||||
overlap_margins: list[int] | None = None,
|
||||
crop_mode: str = "overlap-and-resize-c2",
|
||||
patch_size: int = 14,
|
||||
pooling_size: list[int] = [2, 2],
|
||||
pooling_size: list[int] | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if overlap_margins is None:
|
||||
overlap_margins = [4, 4]
|
||||
if pooling_size is None:
|
||||
pooling_size = [2, 2]
|
||||
size = size if size is not None else {"height": 378, "width": 378}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
self.size = size
|
||||
+5
-8
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,16 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Inference utilities for MolmoAct2"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import functional as F # noqa: N812
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
@@ -679,7 +676,7 @@ def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
|
||||
|
||||
|
||||
def _copy_context_(dst: Any, src: Any) -> None:
|
||||
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
|
||||
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts, strict=False):
|
||||
dst_k.copy_(src_k)
|
||||
dst_v.copy_(src_v)
|
||||
if src.cross_mask is not None:
|
||||
@@ -689,7 +686,7 @@ def _copy_context_(dst: Any, src: Any) -> None:
|
||||
if src.valid_action is not None:
|
||||
dst.valid_action.copy_(src.valid_action)
|
||||
if src.rope_cache is not None:
|
||||
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
|
||||
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache, strict=False):
|
||||
dst_tensor.copy_(src_tensor)
|
||||
|
||||
|
||||
+11
-12
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,24 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Modeling code for MolmoAct2"""
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import functional as F # noqa: N812
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
@@ -647,7 +646,7 @@ class ActionExpert(nn.Module):
|
||||
f"got {len(encoder_kv_states)}."
|
||||
)
|
||||
kv_contexts = []
|
||||
for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states):
|
||||
for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states, strict=False):
|
||||
k_ctx = self._project_kv_tensor(k_in, self.context_k_proj)
|
||||
v_ctx = self._project_kv_tensor(v_in, self.context_v_proj)
|
||||
k_norm = block.cross_attn.k_norm
|
||||
@@ -732,7 +731,7 @@ class ActionExpert(nn.Module):
|
||||
timesteps: Sequence[torch.Tensor],
|
||||
) -> Sequence[ActionExpertStepModulation]:
|
||||
cache = []
|
||||
for idx, step_t in enumerate(timesteps):
|
||||
for _idx, step_t in enumerate(timesteps):
|
||||
conditioning = self._time_conditioning(step_t)
|
||||
block_modulations = []
|
||||
for block in self.blocks:
|
||||
@@ -786,8 +785,8 @@ class ActionExpert(nn.Module):
|
||||
x = self.action_embed(actions)
|
||||
if context.valid_action is not None:
|
||||
x = x * context.valid_action
|
||||
for idx, (block, kv_context, block_modulation) in enumerate(
|
||||
zip(self.blocks, context.kv_contexts, block_modulations)
|
||||
for _idx, (block, kv_context, block_modulation) in enumerate(
|
||||
zip(self.blocks, context.kv_contexts, block_modulations, strict=False)
|
||||
):
|
||||
x = block(
|
||||
x,
|
||||
@@ -2874,7 +2873,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel):
|
||||
depth_mask=depth_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
for gate, source in zip(gate_head, sources)
|
||||
for gate, source in zip(gate_head, sources, strict=False)
|
||||
]
|
||||
return gates, depth_mask
|
||||
gate = self._depth_gate_from_source(
|
||||
@@ -4458,7 +4457,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from lerobot.policies.molmoact2.hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
|
||||
>>> from lerobot.policies.molmoact2.molmoact2_hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
|
||||
>>> from lerobot.policies.molmoact2.processor_molmoact2 import _load_local_molmoact2_processor
|
||||
|
||||
>>> model = MolmoAct2ForConditionalGeneration.from_pretrained("...")
|
||||
+17
-25
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,45 +12,39 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""
|
||||
Processor class for MolmoAct2.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
import dataclasses
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.video_utils import VideoInput
|
||||
from transformers.processing_utils import (
|
||||
Unpack,
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
Unpack,
|
||||
)
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils import logging
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
|
||||
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
|
||||
|
||||
from .image_processing_molmoact2 import MolmoAct2ImageProcessor, MolmoAct2ImagesKwargs
|
||||
from .video_processing_molmoact2 import MolmoAct2VideoProcessor, MolmoAct2VideoProcessorKwargs
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
|
||||
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
|
||||
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
|
||||
IM_START_TOKEN = f"<im_start>"
|
||||
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
|
||||
FRAME_START_TOKEN = f"<frame_start>"
|
||||
IM_END_TOKEN = f"<im_end>"
|
||||
FRAME_END_TOKEN = f"<frame_end>"
|
||||
IM_COL_TOKEN = f"<im_col>"
|
||||
IMAGE_PATCH_TOKEN = "<im_patch>" # nosec B105 # Where to insert high-res tokens
|
||||
IMAGE_LOW_RES_TOKEN = "<im_low>" # nosec B105 # Where to insert low-res tokens
|
||||
IM_START_TOKEN = "<im_start>" # nosec B105
|
||||
LOW_RES_IMAGE_START_TOKEN = "<low_res_im_start>" # nosec B105
|
||||
FRAME_START_TOKEN = "<frame_start>" # nosec B105
|
||||
IM_END_TOKEN = "<im_end>" # nosec B105
|
||||
FRAME_END_TOKEN = "<frame_end>" # nosec B105
|
||||
IM_COL_TOKEN = "<im_col>" # nosec B105
|
||||
IMAGE_PROMPT = "<|image|>"
|
||||
VIDEO_PROMPT = "<|video|>"
|
||||
|
||||
@@ -224,7 +216,7 @@ class MolmoAct2Processor(ProcessorMixin):
|
||||
input_ids = input_ids[None, :]
|
||||
attention_mask = attention_mask[None, :]
|
||||
|
||||
B, S = input_ids.shape
|
||||
B, S = input_ids.shape # noqa: N806
|
||||
|
||||
# Handle zero-length sequence
|
||||
if S == 0:
|
||||
@@ -364,7 +356,7 @@ class MolmoAct2Processor(ProcessorMixin):
|
||||
assert num_videos in {0, 1}, "At most one video is supported for now"
|
||||
video_grids_i = video_grids[index : index + num_videos]
|
||||
metadata_i = video_metadata[index : index + num_videos]
|
||||
for video_grid, metadata in zip(video_grids_i, metadata_i):
|
||||
for video_grid, metadata in zip(video_grids_i, metadata_i, strict=False):
|
||||
video_string = self.get_video_string(
|
||||
video_grid,
|
||||
metadata.timestamps,
|
||||
+29
-34
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,25 +12,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Video processor class for MolmoAct2"""
|
||||
|
||||
from functools import partial
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import redirect_stdout
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import requests
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
@@ -41,27 +37,24 @@ from transformers.image_utils import (
|
||||
SizeDict,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.video_utils import (
|
||||
VideoInput,
|
||||
is_valid_video,
|
||||
make_batched_videos,
|
||||
make_batched_metadata,
|
||||
VideoMetadata,
|
||||
)
|
||||
from transformers.processing_utils import Unpack, VideosKwargs
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
from transformers.utils import logging
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.utils import (
|
||||
TensorType,
|
||||
is_av_available,
|
||||
is_decord_available,
|
||||
is_torchcodec_available,
|
||||
is_yt_dlp_available,
|
||||
TensorType,
|
||||
logging,
|
||||
to_numpy,
|
||||
)
|
||||
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
from transformers.video_utils import (
|
||||
VideoInput,
|
||||
VideoMetadata,
|
||||
is_valid_video,
|
||||
make_batched_metadata,
|
||||
make_batched_videos,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -102,8 +95,8 @@ def resize_image(
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
||||
else:
|
||||
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
|
||||
image.dtype
|
||||
assert image.dtype == torch.uint8, (
|
||||
f"SigLIP expects float images or uint8 images, but got {image.dtype}"
|
||||
)
|
||||
in_min = 0.0
|
||||
in_max = 255.0
|
||||
@@ -548,9 +541,8 @@ def get_target_fps(
|
||||
step_size = max(int(video_fps / target_fps), 1)
|
||||
num_frames_sampled_at_fps = int(total_frames / step_size)
|
||||
if num_frames_sampled == 0:
|
||||
if "uniform" in frame_sample_mode:
|
||||
if num_frames_sampled_at_fps > max_frames:
|
||||
break
|
||||
if "uniform" in frame_sample_mode and num_frames_sampled_at_fps > max_frames:
|
||||
break
|
||||
selected_target_fps = target_fps
|
||||
num_frames_sampled = num_frames_sampled_at_fps
|
||||
|
||||
@@ -779,13 +771,15 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
|
||||
elif is_torchcodec_available():
|
||||
warnings.warn(
|
||||
"`decord` is not installed and cannot be used to decode the video by default. "
|
||||
"Falling back to `torchcodec`."
|
||||
"Falling back to `torchcodec`.",
|
||||
stacklevel=2,
|
||||
)
|
||||
backend = "torchcodec"
|
||||
else:
|
||||
warnings.warn(
|
||||
"`decord` is not installed and cannot be used to decode the video by default. "
|
||||
"Falling back to `PyAV`."
|
||||
"Falling back to `PyAV`.",
|
||||
stacklevel=2,
|
||||
)
|
||||
backend = "pyav"
|
||||
|
||||
@@ -795,7 +789,8 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
|
||||
*[
|
||||
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
|
||||
for x in video_url_or_urls
|
||||
]
|
||||
],
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -821,7 +816,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
|
||||
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
|
||||
sampled_videos = []
|
||||
sampled_metadata = []
|
||||
for video, metadata in zip(videos, video_metadata):
|
||||
for video, metadata in zip(videos, video_metadata, strict=False):
|
||||
indices = sample_indices_fn(metadata=metadata)
|
||||
metadata.frames_indices = indices
|
||||
sampled_videos.append(video[indices])
|
||||
@@ -985,11 +980,11 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
|
||||
pixel_values_videos = np.concatenate(batch_crops, 0)
|
||||
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
||||
|
||||
data = dict(
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
video_token_pooling=video_token_pooling,
|
||||
video_grids=video_grids,
|
||||
)
|
||||
data = {
|
||||
"pixel_values_videos": pixel_values_videos,
|
||||
"video_token_pooling": video_token_pooling,
|
||||
"video_grids": video_grids,
|
||||
}
|
||||
|
||||
return BatchFeature(data, tensor_type=return_tensors)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,10 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MolmoAct2 pre/post processing pipeline.
|
||||
|
||||
Builds the multimodal prompt (images, discretised state, task text),
|
||||
tokenises it via the vendored MolmoAct2 processor, and handles quantile
|
||||
normalisation with optional per-dimension gripper masking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
@@ -27,7 +33,6 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
@@ -54,14 +59,71 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
|
||||
|
||||
from .configuration_molmoact2 import MolmoAct2Config, infer_molmoact2_max_sequence_length
|
||||
from .configuration_molmoact2 import MolmoAct2Config
|
||||
from .modeling_molmoact2 import _hf_token, _resolve_checkpoint_location
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
|
||||
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
|
||||
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
|
||||
MOLMOACT2_TASK_TOKEN_BUDGET = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
|
||||
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
|
||||
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
|
||||
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
|
||||
|
||||
|
||||
def _round_up(value: int, multiple: int) -> int:
|
||||
return int(math.ceil(value / multiple) * multiple)
|
||||
|
||||
|
||||
def infer_molmoact2_max_sequence_length(
|
||||
*,
|
||||
num_images: int,
|
||||
state_dim: int,
|
||||
action_dim: int,
|
||||
action_horizon: int,
|
||||
include_discrete_action: bool,
|
||||
) -> int:
|
||||
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
|
||||
if num_images < 1:
|
||||
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
|
||||
if state_dim < 0:
|
||||
state_dim = 0
|
||||
if action_dim < 1:
|
||||
action_dim = 1
|
||||
if action_horizon < 1:
|
||||
action_horizon = 1
|
||||
|
||||
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
|
||||
prompt_tokens = (
|
||||
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
|
||||
+ MOLMOACT2_TASK_TOKEN_BUDGET
|
||||
+ state_dim
|
||||
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
|
||||
)
|
||||
action_tokens = 0
|
||||
if include_discrete_action:
|
||||
action_tokens_per_step = max(
|
||||
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
|
||||
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
|
||||
)
|
||||
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
|
||||
|
||||
return _round_up(
|
||||
image_tokens + prompt_tokens + action_tokens,
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
from .hf_model.image_processing_molmoact2 import MolmoAct2ImageProcessor
|
||||
from .hf_model.processing_molmoact2 import MolmoAct2Processor
|
||||
from .hf_model.video_processing_molmoact2 import MolmoAct2VideoProcessor
|
||||
from .molmoact2_hf_model.image_processing_molmoact2 import MolmoAct2ImageProcessor
|
||||
from .molmoact2_hf_model.processing_molmoact2 import MolmoAct2Processor
|
||||
from .molmoact2_hf_model.video_processing_molmoact2 import MolmoAct2VideoProcessor
|
||||
else:
|
||||
Qwen2Tokenizer = None
|
||||
MolmoAct2ImageProcessor = None
|
||||
@@ -69,7 +131,7 @@ else:
|
||||
MolmoAct2VideoProcessor = None
|
||||
|
||||
if TYPE_CHECKING or (_transformers_available and _scipy_available):
|
||||
from .hf_model.action_tokenizer import UniversalActionProcessor
|
||||
from .molmoact2_hf_model.action_tokenizer import UniversalActionProcessor
|
||||
else:
|
||||
UniversalActionProcessor = None
|
||||
|
||||
@@ -97,32 +159,6 @@ _QUESTION_PREFIX_PATTERNS = tuple(
|
||||
)
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_checkpoint_location(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
checkpoint_path = str(checkpoint_path or "").strip()
|
||||
if not checkpoint_path:
|
||||
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
|
||||
local_path = Path(checkpoint_path).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=checkpoint_path,
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
def _load_hf_norm_stats_for_tag(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -35,16 +33,16 @@ pytest.importorskip("scipy")
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies import get_policy_class, make_policy_config
|
||||
from lerobot.policies.molmoact2 import (
|
||||
configuration_molmoact2 as molmoact2_config,
|
||||
modeling_molmoact2 as molmoact2_modeling,
|
||||
processor_molmoact2 as molmoact2_processor,
|
||||
)
|
||||
from lerobot.policies.molmoact2.configuration_molmoact2 import (
|
||||
MolmoAct2Config,
|
||||
MolmoAct2CosineDecayWithWarmupSchedulerConfig,
|
||||
infer_molmoact2_max_sequence_length,
|
||||
from lerobot.policies.molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
from lerobot.policies.molmoact2.modeling_molmoact2 import (
|
||||
MolmoAct2Policy,
|
||||
_apply_action_chunk_padding_mask,
|
||||
_apply_action_dim_padding_mask,
|
||||
_combine_rollout_seeds,
|
||||
)
|
||||
from lerobot.policies.molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||
from lerobot.policies.molmoact2.processor_molmoact2 import (
|
||||
MolmoAct2ClampNormalizedProcessorStep,
|
||||
MolmoAct2MaskedNormalizerProcessorStep,
|
||||
@@ -53,6 +51,7 @@ from lerobot.policies.molmoact2.processor_molmoact2 import (
|
||||
_add_gripper_masks_to_stats,
|
||||
_build_discrete_state_string,
|
||||
_normalize_question_text,
|
||||
infer_molmoact2_max_sequence_length,
|
||||
make_molmoact2_pre_post_processors,
|
||||
)
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
@@ -71,34 +70,38 @@ def test_molmoact2_policy_registration():
|
||||
assert cfg.per_episode_seed is False
|
||||
assert cfg.eval_seed is None
|
||||
assert cfg.normalize_language is True
|
||||
assert cfg.get_scheduler_preset().num_decay_steps is None
|
||||
assert cfg.get_scheduler_preset().num_decay_steps == 100_000
|
||||
assert cfg.action_delta_indices == list(range(cfg.chunk_size))
|
||||
assert get_policy_class("molmoact2") is MolmoAct2Policy
|
||||
|
||||
|
||||
def test_molmoact2_checkpoint_download_ignores_remote_python(monkeypatch):
|
||||
import huggingface_hub
|
||||
|
||||
download_kwargs = {}
|
||||
|
||||
def fake_snapshot_download(**kwargs):
|
||||
download_kwargs.update(kwargs)
|
||||
return "/tmp/downloaded-molmoact2"
|
||||
|
||||
monkeypatch.setattr(molmoact2_config, "snapshot_download", fake_snapshot_download)
|
||||
monkeypatch.setattr(huggingface_hub, "snapshot_download", fake_snapshot_download)
|
||||
|
||||
checkpoint_location = molmoact2_config._resolve_checkpoint_location("allenai/MolmoAct2")
|
||||
checkpoint_location = molmoact2_modeling._resolve_checkpoint_location("allenai/MolmoAct2")
|
||||
|
||||
assert checkpoint_location == "/tmp/downloaded-molmoact2"
|
||||
assert download_kwargs["ignore_patterns"] == ["*.py", "*.pyc", "__pycache__/*"]
|
||||
|
||||
|
||||
def test_molmoact2_scheduler_decay_steps_auto_match_training_steps():
|
||||
def test_molmoact2_scheduler_auto_scales_to_training_steps():
|
||||
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
param = torch.nn.Parameter(torch.ones(()))
|
||||
optimizer = torch.optim.AdamW([param], lr=0.001)
|
||||
config = MolmoAct2CosineDecayWithWarmupSchedulerConfig(
|
||||
config = CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=0.01,
|
||||
decay_lr=0.001,
|
||||
num_warmup_steps=10,
|
||||
num_decay_steps=None,
|
||||
num_decay_steps=100_000,
|
||||
)
|
||||
|
||||
scheduler = config.build(optimizer, num_training_steps=100)
|
||||
@@ -123,9 +126,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task():
|
||||
batch_size=3,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
expected_first = torch.Generator().manual_seed(
|
||||
MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3)
|
||||
)
|
||||
expected_first = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1000, batch_size=3))
|
||||
assert torch.allclose(torch.rand(4, generator=first), torch.rand(4, generator=expected_first))
|
||||
|
||||
policy.reset()
|
||||
@@ -134,9 +135,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task():
|
||||
batch_size=3,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
expected_second = torch.Generator().manual_seed(
|
||||
MolmoAct2Policy._combine_rollout_seeds(first_seed=1003, batch_size=3)
|
||||
)
|
||||
expected_second = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1003, batch_size=3))
|
||||
assert torch.allclose(torch.rand(4, generator=second), torch.rand(4, generator=expected_second))
|
||||
|
||||
policy.reset()
|
||||
@@ -145,9 +144,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task():
|
||||
batch_size=3,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
expected_new_task = torch.Generator().manual_seed(
|
||||
MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3)
|
||||
)
|
||||
expected_new_task = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1000, batch_size=3))
|
||||
assert torch.allclose(torch.rand(4, generator=new_task), torch.rand(4, generator=expected_new_task))
|
||||
|
||||
|
||||
@@ -537,36 +534,26 @@ def test_train_action_expert_only_requires_continuous_action_mode():
|
||||
|
||||
|
||||
def test_molmoact2_sequence_length_is_inferred_from_fixed_token_budget():
|
||||
cfg = MolmoAct2Config(
|
||||
action_mode="both",
|
||||
chunk_size=10,
|
||||
n_action_steps=10,
|
||||
image_keys=["observation.images.image", "observation.images.wrist_image"],
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,))},
|
||||
)
|
||||
|
||||
assert cfg.max_sequence_length is None
|
||||
assert cfg.inferred_max_sequence_length() == 640
|
||||
assert cfg.inferred_max_sequence_length(include_discrete_action=False) == 576
|
||||
assert (
|
||||
infer_molmoact2_max_sequence_length(
|
||||
num_images=2,
|
||||
state_dim=8,
|
||||
action_dim=7,
|
||||
action_horizon=30,
|
||||
include_discrete_action=True,
|
||||
num_images=2, state_dim=8, action_dim=7, action_horizon=10, include_discrete_action=True
|
||||
)
|
||||
== 640
|
||||
)
|
||||
assert (
|
||||
infer_molmoact2_max_sequence_length(
|
||||
num_images=2, state_dim=8, action_dim=7, action_horizon=10, include_discrete_action=False
|
||||
)
|
||||
== 576
|
||||
)
|
||||
assert (
|
||||
infer_molmoact2_max_sequence_length(
|
||||
num_images=2, state_dim=8, action_dim=7, action_horizon=30, include_discrete_action=True
|
||||
)
|
||||
== 768
|
||||
)
|
||||
|
||||
|
||||
def test_molmoact2_sequence_length_override_is_preserved():
|
||||
cfg = MolmoAct2Config(max_sequence_length=1024)
|
||||
|
||||
assert cfg.inferred_max_sequence_length(num_images=2, state_dim=8, action_dim=7) == 1024
|
||||
|
||||
|
||||
def test_train_action_expert_only_freezes_non_action_expert_params():
|
||||
class DummyBackbone(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@@ -963,7 +950,7 @@ def test_action_dim_padding_loss_reduces_like_old_trainer():
|
||||
]
|
||||
)
|
||||
|
||||
reduced = MolmoAct2Policy._apply_action_dim_padding_mask(loss, action_dim_is_pad)
|
||||
reduced = _apply_action_dim_padding_mask(loss, action_dim_is_pad)
|
||||
|
||||
expected = torch.stack(
|
||||
[
|
||||
@@ -979,7 +966,7 @@ def test_action_chunk_padding_keeps_old_mean_denominator():
|
||||
loss = torch.ones(1, 2, 4, 3)
|
||||
action_horizon_is_pad = torch.tensor([[False, False, True, True]])
|
||||
|
||||
masked = MolmoAct2Policy._apply_action_chunk_padding_mask(loss, action_horizon_is_pad)
|
||||
masked = _apply_action_chunk_padding_mask(loss, action_horizon_is_pad)
|
||||
|
||||
assert masked.mean().item() == 0.5
|
||||
|
||||
|
||||
Reference in New Issue
Block a user