mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
move molmoact2 config logic into config
This commit is contained in:
@@ -16,10 +16,16 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import (
|
||||
AdamWConfig,
|
||||
@@ -42,6 +48,71 @@ MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
|
||||
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_checkpoint_location(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
checkpoint_path = str(checkpoint_path or "").strip()
|
||||
if not checkpoint_path:
|
||||
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
|
||||
local_path = Path(checkpoint_path).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=checkpoint_path,
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
def _load_hf_norm_metadata_for_tag(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None,
|
||||
force_download: bool,
|
||||
norm_tag: str | None,
|
||||
) -> dict[str, Any]:
|
||||
norm_tag = str(norm_tag or "").strip()
|
||||
if not norm_tag:
|
||||
return {}
|
||||
checkpoint_location = Path(
|
||||
_resolve_checkpoint_location(
|
||||
checkpoint_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
norm_stats_filename = "norm_stats.json"
|
||||
config_path = checkpoint_location / "config.json"
|
||||
if config_path.exists():
|
||||
with suppress(OSError, json.JSONDecodeError):
|
||||
norm_stats_filename = str(
|
||||
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
|
||||
)
|
||||
stats_path = checkpoint_location / norm_stats_filename
|
||||
if not stats_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
|
||||
)
|
||||
payload = json.loads(stats_path.read_text())
|
||||
metadata_by_tag = payload.get("metadata_by_tag")
|
||||
if not isinstance(metadata_by_tag, dict):
|
||||
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
|
||||
metadata = metadata_by_tag.get(norm_tag)
|
||||
if not isinstance(metadata, dict):
|
||||
available = sorted(str(tag) for tag in metadata_by_tag)
|
||||
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
|
||||
return metadata
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
|
||||
@@ -355,3 +426,96 @@ class MolmoAct2Config(PreTrainedConfig):
|
||||
shape=(self.expected_max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def apply_norm_tag_metadata(self) -> None:
|
||||
if not str(self.norm_tag or "").strip():
|
||||
return
|
||||
metadata = _load_hf_norm_metadata_for_tag(
|
||||
self.checkpoint_path,
|
||||
revision=self.checkpoint_revision,
|
||||
force_download=bool(self.checkpoint_force_download),
|
||||
norm_tag=self.norm_tag,
|
||||
)
|
||||
if metadata.get("action_horizon") is not None:
|
||||
self.chunk_size = int(metadata["action_horizon"])
|
||||
if metadata.get("n_action_steps") is not None:
|
||||
self.n_action_steps = int(metadata["n_action_steps"])
|
||||
if not self.setup_type and metadata.get("setup_type") is not None:
|
||||
self.setup_type = str(metadata["setup_type"])
|
||||
if not self.control_mode and metadata.get("control_mode") is not None:
|
||||
self.control_mode = str(metadata["control_mode"])
|
||||
if not self.image_keys and isinstance(metadata.get("camera_keys"), list):
|
||||
self.image_keys = [str(key) for key in metadata["camera_keys"]]
|
||||
|
||||
def saved_policy_action_mode(self) -> str | None:
|
||||
pretrained_path = getattr(self, "pretrained_path", None)
|
||||
if pretrained_path is None:
|
||||
return None
|
||||
config_path = Path(pretrained_path) / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
try:
|
||||
mode = json.loads(config_path.read_text()).get("action_mode")
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
if mode in {"continuous", "discrete", "both"}:
|
||||
return str(mode)
|
||||
return None
|
||||
|
||||
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
|
||||
return saved_policy_action_mode or self.action_mode
|
||||
|
||||
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
|
||||
requested_mode = self.inference_action_mode
|
||||
if requested_mode is None:
|
||||
return
|
||||
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
|
||||
"continuous inference."
|
||||
)
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
|
||||
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
|
||||
)
|
||||
|
||||
def validate_checkpoint_action_mode(
|
||||
self,
|
||||
checkpoint_action_mode: str,
|
||||
*,
|
||||
has_action_expert: bool,
|
||||
) -> None:
|
||||
if self.action_mode == "both" and checkpoint_action_mode != "both":
|
||||
raise ValueError(
|
||||
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
|
||||
raise ValueError(
|
||||
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
|
||||
f"got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.action_mode in {"continuous", "both"} and not has_action_expert:
|
||||
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
|
||||
|
||||
def resolve_inference_action_mode(
|
||||
self,
|
||||
requested_mode: str | None,
|
||||
saved_policy_action_mode: str | None = None,
|
||||
) -> str:
|
||||
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||
if requested_mode is None:
|
||||
requested_mode = self.inference_action_mode
|
||||
if requested_mode is None:
|
||||
raise ValueError(
|
||||
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
|
||||
"to either 'continuous' or 'discrete'."
|
||||
)
|
||||
if requested_mode not in {"continuous", "discrete"}:
|
||||
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
|
||||
return requested_mode
|
||||
|
||||
@@ -16,19 +16,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import types
|
||||
from collections import deque
|
||||
from contextlib import nullcontext, suppress
|
||||
from pathlib import Path
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torch.utils.checkpoint
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch import Tensor
|
||||
from torch.distributions import Beta
|
||||
|
||||
@@ -37,7 +33,7 @@ from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from ..rtc.modeling_rtc import RTCProcessor
|
||||
from .configuration_molmoact2 import MolmoAct2Config
|
||||
from .configuration_molmoact2 import MolmoAct2Config, _hf_token, _resolve_checkpoint_location
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor
|
||||
@@ -62,71 +58,6 @@ _MODEL_INPUT_KEYS = {
|
||||
}
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_checkpoint_location(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
checkpoint_path = str(checkpoint_path or "").strip()
|
||||
if not checkpoint_path:
|
||||
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
|
||||
local_path = Path(checkpoint_path).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=checkpoint_path,
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
def _load_hf_norm_metadata_for_tag(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None,
|
||||
force_download: bool,
|
||||
norm_tag: str | None,
|
||||
) -> dict[str, Any]:
|
||||
norm_tag = str(norm_tag or "").strip()
|
||||
if not norm_tag:
|
||||
return {}
|
||||
checkpoint_location = Path(
|
||||
_resolve_checkpoint_location(
|
||||
checkpoint_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
norm_stats_filename = "norm_stats.json"
|
||||
config_path = checkpoint_location / "config.json"
|
||||
if config_path.exists():
|
||||
with suppress(OSError, json.JSONDecodeError):
|
||||
norm_stats_filename = str(
|
||||
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
|
||||
)
|
||||
stats_path = checkpoint_location / norm_stats_filename
|
||||
if not stats_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
|
||||
)
|
||||
payload = json.loads(stats_path.read_text())
|
||||
metadata_by_tag = payload.get("metadata_by_tag")
|
||||
if not isinstance(metadata_by_tag, dict):
|
||||
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
|
||||
metadata = metadata_by_tag.get(norm_tag)
|
||||
if not isinstance(metadata, dict):
|
||||
available = sorted(str(tag) for tag in metadata_by_tag)
|
||||
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
|
||||
return metadata
|
||||
|
||||
|
||||
def _torch_dtype(dtype: str) -> torch.dtype:
|
||||
if dtype == "float32":
|
||||
return torch.float32
|
||||
@@ -648,10 +579,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self._checkpoint_action_mode = self._load_saved_policy_action_mode()
|
||||
self._apply_norm_tag_metadata()
|
||||
self.config.apply_norm_tag_metadata()
|
||||
self.config.validate_features()
|
||||
del inputs, kwargs, dataset_stats, dataset_meta
|
||||
self._checkpoint_action_mode = self.config.saved_policy_action_mode()
|
||||
self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps)
|
||||
self._rollout_action_generator: torch.Generator | None = None
|
||||
self._rollout_task_key: tuple[Any, ...] | None = None
|
||||
@@ -659,65 +590,11 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
self.rtc_processor: RTCProcessor | None = None
|
||||
self.action_tokenizer: Any | None = None
|
||||
self._load_hf_model()
|
||||
self._validate_inference_action_mode()
|
||||
self.config.validate_inference_action_mode(self._checkpoint_action_mode)
|
||||
if self.config.enable_lora_vlm:
|
||||
self._apply_lora_adapters()
|
||||
self.init_rtc_processor()
|
||||
|
||||
def _load_saved_policy_action_mode(self) -> str | None:
|
||||
pretrained_path = getattr(self.config, "pretrained_path", None)
|
||||
if pretrained_path is None:
|
||||
return None
|
||||
config_path = Path(pretrained_path) / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
try:
|
||||
mode = json.loads(config_path.read_text()).get("action_mode")
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
if mode in {"continuous", "discrete", "both"}:
|
||||
return str(mode)
|
||||
return None
|
||||
|
||||
def _training_action_mode(self) -> str:
|
||||
return getattr(self, "_checkpoint_action_mode", None) or self.config.action_mode
|
||||
|
||||
def _validate_inference_action_mode(self) -> None:
|
||||
requested_mode = self.config.inference_action_mode
|
||||
if requested_mode is None:
|
||||
return
|
||||
training_mode = self._training_action_mode()
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
|
||||
"continuous inference."
|
||||
)
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
|
||||
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
|
||||
)
|
||||
|
||||
def _apply_norm_tag_metadata(self) -> None:
|
||||
if not str(self.config.norm_tag or "").strip():
|
||||
return
|
||||
metadata = _load_hf_norm_metadata_for_tag(
|
||||
self.config.checkpoint_path,
|
||||
revision=self.config.checkpoint_revision,
|
||||
force_download=bool(self.config.checkpoint_force_download),
|
||||
norm_tag=self.config.norm_tag,
|
||||
)
|
||||
if metadata.get("action_horizon") is not None:
|
||||
self.config.chunk_size = int(metadata["action_horizon"])
|
||||
if metadata.get("n_action_steps") is not None:
|
||||
self.config.n_action_steps = int(metadata["n_action_steps"])
|
||||
if not self.config.setup_type and metadata.get("setup_type") is not None:
|
||||
self.config.setup_type = str(metadata["setup_type"])
|
||||
if not self.config.control_mode and metadata.get("control_mode") is not None:
|
||||
self.config.control_mode = str(metadata["control_mode"])
|
||||
if not self.config.image_keys and isinstance(metadata.get("camera_keys"), list):
|
||||
self.config.image_keys = [str(key) for key in metadata["camera_keys"]]
|
||||
|
||||
def _load_hf_model(self) -> None:
|
||||
require_package("transformers", extra="molmoact2")
|
||||
|
||||
@@ -756,19 +633,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
"`policy.checkpoint_force_download=true` after the updated files are pushed."
|
||||
)
|
||||
checkpoint_action_mode = str(self.model.config.action_mode)
|
||||
if self.config.action_mode == "both" and checkpoint_action_mode != "both":
|
||||
raise ValueError(
|
||||
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.config.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
|
||||
raise ValueError(
|
||||
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
|
||||
f"got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.config.action_mode in {"continuous", "both"} and not bool(
|
||||
getattr(self.model.config, "add_action_expert", False)
|
||||
):
|
||||
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
|
||||
self.config.validate_checkpoint_action_mode(
|
||||
checkpoint_action_mode,
|
||||
has_action_expert=bool(getattr(self.model.config, "add_action_expert", False)),
|
||||
)
|
||||
|
||||
if self.config.freeze_embedding:
|
||||
self._freeze_input_embeddings()
|
||||
@@ -1054,21 +922,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
return self.action_tokenizer
|
||||
|
||||
def _resolve_inference_action_mode(self, requested_mode: str | None) -> str:
|
||||
training_mode = self._training_action_mode()
|
||||
if requested_mode is None:
|
||||
requested_mode = self.config.inference_action_mode
|
||||
if requested_mode is None:
|
||||
raise ValueError(
|
||||
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
|
||||
"to either 'continuous' or 'discrete'."
|
||||
)
|
||||
if requested_mode not in {"continuous", "discrete"}:
|
||||
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
|
||||
return requested_mode
|
||||
return self.config.resolve_inference_action_mode(requested_mode, self._checkpoint_action_mode)
|
||||
|
||||
@staticmethod
|
||||
def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int:
|
||||
|
||||
@@ -653,7 +653,7 @@ def test_select_action_uses_single_full_batch_queue():
|
||||
def test_inference_action_mode_is_explicit_and_has_no_action_mode_alias():
|
||||
policy = object.__new__(MolmoAct2Policy)
|
||||
torch.nn.Module.__init__(policy)
|
||||
policy.config = SimpleNamespace(action_mode="both", inference_action_mode=None)
|
||||
policy.config = MolmoAct2Config(action_mode="both", inference_action_mode=None)
|
||||
policy._checkpoint_action_mode = None
|
||||
|
||||
with pytest.raises(ValueError, match="inference_action_mode.*explicitly"):
|
||||
@@ -1067,11 +1067,11 @@ def test_discrete_predict_action_chunk_uses_hf_cached_generation_path():
|
||||
|
||||
policy = object.__new__(MolmoAct2Policy)
|
||||
torch.nn.Module.__init__(policy)
|
||||
policy.config = SimpleNamespace(
|
||||
policy.config = MolmoAct2Config(
|
||||
action_mode="discrete",
|
||||
inference_action_mode="discrete",
|
||||
model_dtype="float32",
|
||||
output_features={ACTION: SimpleNamespace(shape=(2,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,))},
|
||||
discrete_generation_max_steps=None,
|
||||
discrete_action_tokenizer="unused",
|
||||
trust_remote_code=True,
|
||||
@@ -1079,6 +1079,7 @@ def test_discrete_predict_action_chunk_uses_hf_cached_generation_path():
|
||||
n_action_steps=1,
|
||||
rtc_config=None,
|
||||
)
|
||||
policy._checkpoint_action_mode = None
|
||||
policy.model = DummyModel()
|
||||
policy.action_tokenizer = _DummyActionTokenizer()
|
||||
|
||||
@@ -1154,11 +1155,11 @@ def test_discrete_predict_action_chunk_uses_graph_backed_ar_decode_when_enabled(
|
||||
|
||||
policy = object.__new__(MolmoAct2Policy)
|
||||
torch.nn.Module.__init__(policy)
|
||||
policy.config = SimpleNamespace(
|
||||
policy.config = MolmoAct2Config(
|
||||
action_mode="discrete",
|
||||
inference_action_mode="discrete",
|
||||
model_dtype="float32",
|
||||
output_features={ACTION: SimpleNamespace(shape=(2,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,))},
|
||||
discrete_generation_max_steps=None,
|
||||
discrete_action_tokenizer="unused",
|
||||
trust_remote_code=True,
|
||||
@@ -1167,6 +1168,7 @@ def test_discrete_predict_action_chunk_uses_graph_backed_ar_decode_when_enabled(
|
||||
rtc_config=None,
|
||||
enable_inference_cuda_graph=True,
|
||||
)
|
||||
policy._checkpoint_action_mode = None
|
||||
policy.model = DummyModel()
|
||||
policy.action_tokenizer = _DummyActionTokenizer()
|
||||
torch.nn.Module.train(policy, False)
|
||||
|
||||
Reference in New Issue
Block a user