From f395f36dec644b8a4f394ead78e6e365f16a6487 Mon Sep 17 00:00:00 2001 From: hq-fang <71356829+hq-fang@users.noreply.github.com> Date: Tue, 19 May 2026 22:02:30 +0000 Subject: [PATCH] move molmoact2 config logic into config --- .../molmoact2/configuration_molmoact2.py | 164 +++++++++++++++++ .../policies/molmoact2/modeling_molmoact2.py | 166 ++---------------- tests/policies/molmoact2/test_molmoact2.py | 12 +- 3 files changed, 181 insertions(+), 161 deletions(-) diff --git a/src/lerobot/policies/molmoact2/configuration_molmoact2.py b/src/lerobot/policies/molmoact2/configuration_molmoact2.py index 7d07eb0db..409385ad2 100644 --- a/src/lerobot/policies/molmoact2/configuration_molmoact2.py +++ b/src/lerobot/policies/molmoact2/configuration_molmoact2.py @@ -16,10 +16,16 @@ from __future__ import annotations +import json import math +import os +from contextlib import suppress from dataclasses import dataclass, field +from pathlib import Path from typing import Any +from huggingface_hub import snapshot_download + from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig from lerobot.optim import ( AdamWConfig, @@ -42,6 +48,71 @@ MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6 MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95 +def _hf_token() -> str | None: + return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") + + +def _resolve_checkpoint_location( + checkpoint_path: str, + *, + revision: str | None = None, + force_download: bool = False, +) -> str: + checkpoint_path = str(checkpoint_path or "").strip() + if not checkpoint_path: + raise ValueError("MolmoAct2 policy requires `checkpoint_path`.") + local_path = Path(checkpoint_path).expanduser() + if local_path.exists(): + return str(local_path) + return snapshot_download( + repo_id=checkpoint_path, + repo_type="model", + revision=revision, + force_download=force_download, + token=_hf_token(), + ) + + +def _load_hf_norm_metadata_for_tag( + checkpoint_path: str, + *, + revision: str | None, + force_download: bool, + norm_tag: str | None, +) -> dict[str, Any]: + norm_tag = str(norm_tag or "").strip() + if not norm_tag: + return {} + checkpoint_location = Path( + _resolve_checkpoint_location( + checkpoint_path, + revision=revision, + force_download=force_download, + ) + ) + norm_stats_filename = "norm_stats.json" + config_path = checkpoint_location / "config.json" + if config_path.exists(): + with suppress(OSError, json.JSONDecodeError): + norm_stats_filename = str( + json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename + ) + stats_path = checkpoint_location / norm_stats_filename + if not stats_path.exists(): + raise FileNotFoundError( + f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}." + ) + payload = json.loads(stats_path.read_text()) + metadata_by_tag = payload.get("metadata_by_tag") + if not isinstance(metadata_by_tag, dict): + raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.") + metadata = metadata_by_tag.get(norm_tag) + if not isinstance(metadata, dict): + available = sorted(str(tag) for tag in metadata_by_tag) + raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.") + return metadata + + @LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup") @dataclass class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig): @@ -355,3 +426,96 @@ class MolmoAct2Config(PreTrainedConfig): shape=(self.expected_max_action_dim,), ) self.output_features[ACTION] = action_feature + + def apply_norm_tag_metadata(self) -> None: + if not str(self.norm_tag or "").strip(): + return + metadata = _load_hf_norm_metadata_for_tag( + self.checkpoint_path, + revision=self.checkpoint_revision, + force_download=bool(self.checkpoint_force_download), + norm_tag=self.norm_tag, + ) + if metadata.get("action_horizon") is not None: + self.chunk_size = int(metadata["action_horizon"]) + if metadata.get("n_action_steps") is not None: + self.n_action_steps = int(metadata["n_action_steps"]) + if not self.setup_type and metadata.get("setup_type") is not None: + self.setup_type = str(metadata["setup_type"]) + if not self.control_mode and metadata.get("control_mode") is not None: + self.control_mode = str(metadata["control_mode"]) + if not self.image_keys and isinstance(metadata.get("camera_keys"), list): + self.image_keys = [str(key) for key in metadata["camera_keys"]] + + def saved_policy_action_mode(self) -> str | None: + pretrained_path = getattr(self, "pretrained_path", None) + if pretrained_path is None: + return None + config_path = Path(pretrained_path) / "config.json" + if not config_path.exists(): + return None + try: + mode = json.loads(config_path.read_text()).get("action_mode") + except (OSError, json.JSONDecodeError): + return None + if mode in {"continuous", "discrete", "both"}: + return str(mode) + return None + + def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str: + return saved_policy_action_mode or self.action_mode + + def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None: + requested_mode = self.inference_action_mode + if requested_mode is None: + return + training_mode = self.training_action_mode(saved_policy_action_mode) + if requested_mode == "continuous" and training_mode == "discrete": + raise ValueError( + "MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run " + "continuous inference." + ) + if requested_mode == "discrete" and training_mode == "continuous": + raise ValueError( + "MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run " + "discrete inference. Train with action_mode='both' or action_mode='discrete' first." + ) + + def validate_checkpoint_action_mode( + self, + checkpoint_action_mode: str, + *, + has_action_expert: bool, + ) -> None: + if self.action_mode == "both" and checkpoint_action_mode != "both": + raise ValueError( + f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}." + ) + if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}: + raise ValueError( + f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, " + f"got {checkpoint_action_mode!r}." + ) + if self.action_mode in {"continuous", "both"} and not has_action_expert: + raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.") + + def resolve_inference_action_mode( + self, + requested_mode: str | None, + saved_policy_action_mode: str | None = None, + ) -> str: + training_mode = self.training_action_mode(saved_policy_action_mode) + if requested_mode is None: + requested_mode = self.inference_action_mode + if requested_mode is None: + raise ValueError( + "MolmoAct2 inference requires `inference_action_mode` to be set explicitly " + "to either 'continuous' or 'discrete'." + ) + if requested_mode not in {"continuous", "discrete"}: + raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.") + if requested_mode == "continuous" and training_mode == "discrete": + raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.") + if requested_mode == "discrete" and training_mode == "continuous": + raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.") + return requested_mode diff --git a/src/lerobot/policies/molmoact2/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/modeling_molmoact2.py index 87ef54513..3b225e19f 100644 --- a/src/lerobot/policies/molmoact2/modeling_molmoact2.py +++ b/src/lerobot/policies/molmoact2/modeling_molmoact2.py @@ -16,19 +16,15 @@ from __future__ import annotations -import json -import os import types from collections import deque -from contextlib import nullcontext, suppress -from pathlib import Path +from contextlib import nullcontext from typing import TYPE_CHECKING, Any import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torch.utils.checkpoint -from huggingface_hub import snapshot_download from torch import Tensor from torch.distributions import Beta @@ -37,7 +33,7 @@ from lerobot.utils.constants import ACTION from lerobot.utils.import_utils import _transformers_available, require_package from ..rtc.modeling_rtc import RTCProcessor -from .configuration_molmoact2 import MolmoAct2Config +from .configuration_molmoact2 import MolmoAct2Config, _hf_token, _resolve_checkpoint_location if TYPE_CHECKING or _transformers_available: from transformers import AutoModelForImageTextToText, AutoProcessor @@ -62,71 +58,6 @@ _MODEL_INPUT_KEYS = { } -def _hf_token() -> str | None: - return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN") - - -def _resolve_checkpoint_location( - checkpoint_path: str, - *, - revision: str | None = None, - force_download: bool = False, -) -> str: - checkpoint_path = str(checkpoint_path or "").strip() - if not checkpoint_path: - raise ValueError("MolmoAct2 policy requires `checkpoint_path`.") - local_path = Path(checkpoint_path).expanduser() - if local_path.exists(): - return str(local_path) - return snapshot_download( - repo_id=checkpoint_path, - repo_type="model", - revision=revision, - force_download=force_download, - token=_hf_token(), - ) - - -def _load_hf_norm_metadata_for_tag( - checkpoint_path: str, - *, - revision: str | None, - force_download: bool, - norm_tag: str | None, -) -> dict[str, Any]: - norm_tag = str(norm_tag or "").strip() - if not norm_tag: - return {} - checkpoint_location = Path( - _resolve_checkpoint_location( - checkpoint_path, - revision=revision, - force_download=force_download, - ) - ) - norm_stats_filename = "norm_stats.json" - config_path = checkpoint_location / "config.json" - if config_path.exists(): - with suppress(OSError, json.JSONDecodeError): - norm_stats_filename = str( - json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename - ) - stats_path = checkpoint_location / norm_stats_filename - if not stats_path.exists(): - raise FileNotFoundError( - f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}." - ) - payload = json.loads(stats_path.read_text()) - metadata_by_tag = payload.get("metadata_by_tag") - if not isinstance(metadata_by_tag, dict): - raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.") - metadata = metadata_by_tag.get(norm_tag) - if not isinstance(metadata, dict): - available = sorted(str(tag) for tag in metadata_by_tag) - raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.") - return metadata - - def _torch_dtype(dtype: str) -> torch.dtype: if dtype == "float32": return torch.float32 @@ -648,10 +579,10 @@ class MolmoAct2Policy(PreTrainedPolicy): **kwargs, ): super().__init__(config, *inputs, **kwargs) - self._checkpoint_action_mode = self._load_saved_policy_action_mode() - self._apply_norm_tag_metadata() + self.config.apply_norm_tag_metadata() self.config.validate_features() del inputs, kwargs, dataset_stats, dataset_meta + self._checkpoint_action_mode = self.config.saved_policy_action_mode() self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps) self._rollout_action_generator: torch.Generator | None = None self._rollout_task_key: tuple[Any, ...] | None = None @@ -659,65 +590,11 @@ class MolmoAct2Policy(PreTrainedPolicy): self.rtc_processor: RTCProcessor | None = None self.action_tokenizer: Any | None = None self._load_hf_model() - self._validate_inference_action_mode() + self.config.validate_inference_action_mode(self._checkpoint_action_mode) if self.config.enable_lora_vlm: self._apply_lora_adapters() self.init_rtc_processor() - def _load_saved_policy_action_mode(self) -> str | None: - pretrained_path = getattr(self.config, "pretrained_path", None) - if pretrained_path is None: - return None - config_path = Path(pretrained_path) / "config.json" - if not config_path.exists(): - return None - try: - mode = json.loads(config_path.read_text()).get("action_mode") - except (OSError, json.JSONDecodeError): - return None - if mode in {"continuous", "discrete", "both"}: - return str(mode) - return None - - def _training_action_mode(self) -> str: - return getattr(self, "_checkpoint_action_mode", None) or self.config.action_mode - - def _validate_inference_action_mode(self) -> None: - requested_mode = self.config.inference_action_mode - if requested_mode is None: - return - training_mode = self._training_action_mode() - if requested_mode == "continuous" and training_mode == "discrete": - raise ValueError( - "MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run " - "continuous inference." - ) - if requested_mode == "discrete" and training_mode == "continuous": - raise ValueError( - "MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run " - "discrete inference. Train with action_mode='both' or action_mode='discrete' first." - ) - - def _apply_norm_tag_metadata(self) -> None: - if not str(self.config.norm_tag or "").strip(): - return - metadata = _load_hf_norm_metadata_for_tag( - self.config.checkpoint_path, - revision=self.config.checkpoint_revision, - force_download=bool(self.config.checkpoint_force_download), - norm_tag=self.config.norm_tag, - ) - if metadata.get("action_horizon") is not None: - self.config.chunk_size = int(metadata["action_horizon"]) - if metadata.get("n_action_steps") is not None: - self.config.n_action_steps = int(metadata["n_action_steps"]) - if not self.config.setup_type and metadata.get("setup_type") is not None: - self.config.setup_type = str(metadata["setup_type"]) - if not self.config.control_mode and metadata.get("control_mode") is not None: - self.config.control_mode = str(metadata["control_mode"]) - if not self.config.image_keys and isinstance(metadata.get("camera_keys"), list): - self.config.image_keys = [str(key) for key in metadata["camera_keys"]] - def _load_hf_model(self) -> None: require_package("transformers", extra="molmoact2") @@ -756,19 +633,10 @@ class MolmoAct2Policy(PreTrainedPolicy): "`policy.checkpoint_force_download=true` after the updated files are pushed." ) checkpoint_action_mode = str(self.model.config.action_mode) - if self.config.action_mode == "both" and checkpoint_action_mode != "both": - raise ValueError( - f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}." - ) - if self.config.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}: - raise ValueError( - f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, " - f"got {checkpoint_action_mode!r}." - ) - if self.config.action_mode in {"continuous", "both"} and not bool( - getattr(self.model.config, "add_action_expert", False) - ): - raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.") + self.config.validate_checkpoint_action_mode( + checkpoint_action_mode, + has_action_expert=bool(getattr(self.model.config, "add_action_expert", False)), + ) if self.config.freeze_embedding: self._freeze_input_embeddings() @@ -1054,21 +922,7 @@ class MolmoAct2Policy(PreTrainedPolicy): return self.action_tokenizer def _resolve_inference_action_mode(self, requested_mode: str | None) -> str: - training_mode = self._training_action_mode() - if requested_mode is None: - requested_mode = self.config.inference_action_mode - if requested_mode is None: - raise ValueError( - "MolmoAct2 inference requires `inference_action_mode` to be set explicitly " - "to either 'continuous' or 'discrete'." - ) - if requested_mode not in {"continuous", "discrete"}: - raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.") - if requested_mode == "continuous" and training_mode == "discrete": - raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.") - if requested_mode == "discrete" and training_mode == "continuous": - raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.") - return requested_mode + return self.config.resolve_inference_action_mode(requested_mode, self._checkpoint_action_mode) @staticmethod def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int: diff --git a/tests/policies/molmoact2/test_molmoact2.py b/tests/policies/molmoact2/test_molmoact2.py index bb272a8a2..23925b074 100644 --- a/tests/policies/molmoact2/test_molmoact2.py +++ b/tests/policies/molmoact2/test_molmoact2.py @@ -653,7 +653,7 @@ def test_select_action_uses_single_full_batch_queue(): def test_inference_action_mode_is_explicit_and_has_no_action_mode_alias(): policy = object.__new__(MolmoAct2Policy) torch.nn.Module.__init__(policy) - policy.config = SimpleNamespace(action_mode="both", inference_action_mode=None) + policy.config = MolmoAct2Config(action_mode="both", inference_action_mode=None) policy._checkpoint_action_mode = None with pytest.raises(ValueError, match="inference_action_mode.*explicitly"): @@ -1067,11 +1067,11 @@ def test_discrete_predict_action_chunk_uses_hf_cached_generation_path(): policy = object.__new__(MolmoAct2Policy) torch.nn.Module.__init__(policy) - policy.config = SimpleNamespace( + policy.config = MolmoAct2Config( action_mode="discrete", inference_action_mode="discrete", model_dtype="float32", - output_features={ACTION: SimpleNamespace(shape=(2,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,))}, discrete_generation_max_steps=None, discrete_action_tokenizer="unused", trust_remote_code=True, @@ -1079,6 +1079,7 @@ def test_discrete_predict_action_chunk_uses_hf_cached_generation_path(): n_action_steps=1, rtc_config=None, ) + policy._checkpoint_action_mode = None policy.model = DummyModel() policy.action_tokenizer = _DummyActionTokenizer() @@ -1154,11 +1155,11 @@ def test_discrete_predict_action_chunk_uses_graph_backed_ar_decode_when_enabled( policy = object.__new__(MolmoAct2Policy) torch.nn.Module.__init__(policy) - policy.config = SimpleNamespace( + policy.config = MolmoAct2Config( action_mode="discrete", inference_action_mode="discrete", model_dtype="float32", - output_features={ACTION: SimpleNamespace(shape=(2,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,))}, discrete_generation_max_steps=None, discrete_action_tokenizer="unused", trust_remote_code=True, @@ -1167,6 +1168,7 @@ def test_discrete_predict_action_chunk_uses_graph_backed_ar_decode_when_enabled( rtc_config=None, enable_inference_cuda_graph=True, ) + policy._checkpoint_action_mode = None policy.model = DummyModel() policy.action_tokenizer = _DummyActionTokenizer() torch.nn.Module.train(policy, False)