move molmoact2 config logic into config

This commit is contained in:
hq-fang
2026-05-19 22:02:30 +00:00
parent 738ba9272f
commit f395f36dec
3 changed files with 181 additions and 161 deletions
@@ -16,10 +16,16 @@
from __future__ import annotations
import json
import math
import os
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import (
AdamWConfig,
@@ -42,6 +48,71 @@ MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
token=_hf_token(),
)
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
@dataclass
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
@@ -355,3 +426,96 @@ class MolmoAct2Config(PreTrainedConfig):
shape=(self.expected_max_action_dim,),
)
self.output_features[ACTION] = action_feature
def apply_norm_tag_metadata(self) -> None:
    """Copy settings from the checkpoint's ``norm_tag`` metadata onto this config.

    Does nothing when ``norm_tag`` is empty. ``chunk_size`` and
    ``n_action_steps`` are overwritten whenever the metadata provides
    ``action_horizon`` / ``n_action_steps``. ``setup_type``, ``control_mode``
    and ``image_keys`` are only filled in when currently falsy.
    """
    tag = str(self.norm_tag or "").strip()
    if not tag:
        return

    meta = _load_hf_norm_metadata_for_tag(
        self.checkpoint_path,
        revision=self.checkpoint_revision,
        force_download=bool(self.checkpoint_force_download),
        norm_tag=self.norm_tag,
    )

    horizon = meta.get("action_horizon")
    if horizon is not None:
        self.chunk_size = int(horizon)

    steps = meta.get("n_action_steps")
    if steps is not None:
        self.n_action_steps = int(steps)

    # The remaining fields never override a value that is already set.
    if not self.setup_type:
        setup = meta.get("setup_type")
        if setup is not None:
            self.setup_type = str(setup)

    if not self.control_mode:
        control = meta.get("control_mode")
        if control is not None:
            self.control_mode = str(control)

    if not self.image_keys:
        cameras = meta.get("camera_keys")
        if isinstance(cameras, list):
            self.image_keys = [str(name) for name in cameras]
def saved_policy_action_mode(self) -> str | None:
pretrained_path = getattr(self, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
return saved_policy_action_mode or self.action_mode
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
requested_mode = self.inference_action_mode
if requested_mode is None:
return
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
"continuous inference."
)
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
)
def validate_checkpoint_action_mode(
    self,
    checkpoint_action_mode: str,
    *,
    has_action_expert: bool,
) -> None:
    """Check that a loaded checkpoint supports this config's ``action_mode``.

    Args:
        checkpoint_action_mode: Action mode reported by the checkpoint.
        has_action_expert: Whether the checkpoint carries an action expert.

    Raises:
        ValueError: If ``action_mode`` is ``"both"`` but the checkpoint is
            not, if ``action_mode`` is ``"discrete"`` but the checkpoint is
            neither ``"discrete"`` nor ``"both"``, or if a continuous-capable
            ``action_mode`` is configured without an action expert.
    """
    configured = self.action_mode

    if configured == "both":
        if checkpoint_action_mode != "both":
            raise ValueError(
                f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
            )
    elif configured == "discrete":
        if checkpoint_action_mode not in {"discrete", "both"}:
            raise ValueError(
                f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
                f"got {checkpoint_action_mode!r}."
            )

    needs_expert = configured in {"continuous", "both"}
    if needs_expert and not has_action_expert:
        raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
def resolve_inference_action_mode(
self,
requested_mode: str | None,
saved_policy_action_mode: str | None = None,
) -> str:
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode is None:
requested_mode = self.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode
@@ -16,19 +16,15 @@
from __future__ import annotations
import json
import os
import types
from collections import deque
from contextlib import nullcontext, suppress
from pathlib import Path
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torch.utils.checkpoint
from huggingface_hub import snapshot_download
from torch import Tensor
from torch.distributions import Beta
@@ -37,7 +33,7 @@ from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import _transformers_available, require_package
from ..rtc.modeling_rtc import RTCProcessor
from .configuration_molmoact2 import MolmoAct2Config
from .configuration_molmoact2 import MolmoAct2Config, _hf_token, _resolve_checkpoint_location
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModelForImageTextToText, AutoProcessor
@@ -62,71 +58,6 @@ _MODEL_INPUT_KEYS = {
}
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
token=_hf_token(),
)
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
def _torch_dtype(dtype: str) -> torch.dtype:
if dtype == "float32":
return torch.float32
@@ -648,10 +579,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
**kwargs,
):
super().__init__(config, *inputs, **kwargs)
self._checkpoint_action_mode = self._load_saved_policy_action_mode()
self._apply_norm_tag_metadata()
self.config.apply_norm_tag_metadata()
self.config.validate_features()
del inputs, kwargs, dataset_stats, dataset_meta
self._checkpoint_action_mode = self.config.saved_policy_action_mode()
self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps)
self._rollout_action_generator: torch.Generator | None = None
self._rollout_task_key: tuple[Any, ...] | None = None
@@ -659,65 +590,11 @@ class MolmoAct2Policy(PreTrainedPolicy):
self.rtc_processor: RTCProcessor | None = None
self.action_tokenizer: Any | None = None
self._load_hf_model()
self._validate_inference_action_mode()
self.config.validate_inference_action_mode(self._checkpoint_action_mode)
if self.config.enable_lora_vlm:
self._apply_lora_adapters()
self.init_rtc_processor()
def _load_saved_policy_action_mode(self) -> str | None:
pretrained_path = getattr(self.config, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def _training_action_mode(self) -> str:
    """Return the action mode this policy is considered to be trained with.

    A truthy ``_checkpoint_action_mode`` attribute wins; if it is missing or
    falsy, ``config.action_mode`` is returned.
    """
    saved_mode = getattr(self, "_checkpoint_action_mode", None)
    if saved_mode:
        return saved_mode
    return self.config.action_mode
def _validate_inference_action_mode(self) -> None:
    """Reject a configured ``inference_action_mode`` the trained mode cannot serve.

    No check is performed when ``config.inference_action_mode`` is ``None``.

    Raises:
        ValueError: If continuous inference is requested for a discrete-only
            training mode, or discrete inference for a continuous-only one.
    """
    wanted = self.config.inference_action_mode
    if wanted is None:
        return

    combo = (wanted, self._training_action_mode())
    if combo == ("continuous", "discrete"):
        raise ValueError(
            "MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
            "continuous inference."
        )
    if combo == ("discrete", "continuous"):
        raise ValueError(
            "MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
            "discrete inference. Train with action_mode='both' or action_mode='discrete' first."
        )
def _apply_norm_tag_metadata(self) -> None:
    """Copy settings from the checkpoint's ``norm_tag`` metadata onto ``self.config``.

    Does nothing when ``config.norm_tag`` is empty. ``chunk_size`` and
    ``n_action_steps`` are overwritten whenever the metadata provides
    ``action_horizon`` / ``n_action_steps``. ``setup_type``, ``control_mode``
    and ``image_keys`` are only filled in when currently falsy.
    """
    cfg = self.config
    if not str(cfg.norm_tag or "").strip():
        return

    meta = _load_hf_norm_metadata_for_tag(
        cfg.checkpoint_path,
        revision=cfg.checkpoint_revision,
        force_download=bool(cfg.checkpoint_force_download),
        norm_tag=cfg.norm_tag,
    )

    horizon = meta.get("action_horizon")
    if horizon is not None:
        cfg.chunk_size = int(horizon)

    steps = meta.get("n_action_steps")
    if steps is not None:
        cfg.n_action_steps = int(steps)

    # The remaining fields never override a value that is already set.
    if not cfg.setup_type:
        setup = meta.get("setup_type")
        if setup is not None:
            cfg.setup_type = str(setup)

    if not cfg.control_mode:
        control = meta.get("control_mode")
        if control is not None:
            cfg.control_mode = str(control)

    if not cfg.image_keys:
        cameras = meta.get("camera_keys")
        if isinstance(cameras, list):
            cfg.image_keys = [str(name) for name in cameras]
def _load_hf_model(self) -> None:
require_package("transformers", extra="molmoact2")
@@ -756,19 +633,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
"`policy.checkpoint_force_download=true` after the updated files are pushed."
)
checkpoint_action_mode = str(self.model.config.action_mode)
if self.config.action_mode == "both" and checkpoint_action_mode != "both":
raise ValueError(
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
)
if self.config.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
raise ValueError(
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
f"got {checkpoint_action_mode!r}."
)
if self.config.action_mode in {"continuous", "both"} and not bool(
getattr(self.model.config, "add_action_expert", False)
):
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
self.config.validate_checkpoint_action_mode(
checkpoint_action_mode,
has_action_expert=bool(getattr(self.model.config, "add_action_expert", False)),
)
if self.config.freeze_embedding:
self._freeze_input_embeddings()
@@ -1054,21 +922,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
return self.action_tokenizer
def _resolve_inference_action_mode(self, requested_mode: str | None) -> str:
training_mode = self._training_action_mode()
if requested_mode is None:
requested_mode = self.config.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode
return self.config.resolve_inference_action_mode(requested_mode, self._checkpoint_action_mode)
@staticmethod
def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int:
+7 -5
View File
@@ -653,7 +653,7 @@ def test_select_action_uses_single_full_batch_queue():
def test_inference_action_mode_is_explicit_and_has_no_action_mode_alias():
policy = object.__new__(MolmoAct2Policy)
torch.nn.Module.__init__(policy)
policy.config = SimpleNamespace(action_mode="both", inference_action_mode=None)
policy.config = MolmoAct2Config(action_mode="both", inference_action_mode=None)
policy._checkpoint_action_mode = None
with pytest.raises(ValueError, match="inference_action_mode.*explicitly"):
@@ -1067,11 +1067,11 @@ def test_discrete_predict_action_chunk_uses_hf_cached_generation_path():
policy = object.__new__(MolmoAct2Policy)
torch.nn.Module.__init__(policy)
policy.config = SimpleNamespace(
policy.config = MolmoAct2Config(
action_mode="discrete",
inference_action_mode="discrete",
model_dtype="float32",
output_features={ACTION: SimpleNamespace(shape=(2,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,))},
discrete_generation_max_steps=None,
discrete_action_tokenizer="unused",
trust_remote_code=True,
@@ -1079,6 +1079,7 @@ def test_discrete_predict_action_chunk_uses_hf_cached_generation_path():
n_action_steps=1,
rtc_config=None,
)
policy._checkpoint_action_mode = None
policy.model = DummyModel()
policy.action_tokenizer = _DummyActionTokenizer()
@@ -1154,11 +1155,11 @@ def test_discrete_predict_action_chunk_uses_graph_backed_ar_decode_when_enabled(
policy = object.__new__(MolmoAct2Policy)
torch.nn.Module.__init__(policy)
policy.config = SimpleNamespace(
policy.config = MolmoAct2Config(
action_mode="discrete",
inference_action_mode="discrete",
model_dtype="float32",
output_features={ACTION: SimpleNamespace(shape=(2,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,))},
discrete_generation_max_steps=None,
discrete_action_tokenizer="unused",
trust_remote_code=True,
@@ -1167,6 +1168,7 @@ def test_discrete_predict_action_chunk_uses_graph_backed_ar_decode_when_enabled(
rtc_config=None,
enable_inference_cuda_graph=True,
)
policy._checkpoint_action_mode = None
policy.model = DummyModel()
policy.action_tokenizer = _DummyActionTokenizer()
torch.nn.Module.train(policy, False)