fix molmoact2 pre-commit checks

This commit is contained in:
hq-fang
2026-05-22 07:25:18 +00:00
parent 0a369e104a
commit dca792951e
7 changed files with 391 additions and 394 deletions
+2
View File
@@ -405,6 +405,8 @@ default.extend-ignore-identifiers-re = [
"ein",
"thw",
"inpt",
"arange",
"is_compileable",
"ROBOTIS",
"OT_VALUE",
"VanderBilt"
@@ -206,7 +206,7 @@ class MolmoAct2TextConfig(PretrainedConfig):
self,
hidden_size: int = 3584,
num_attention_heads: int = 28,
num_key_value_heads: Optional[int] = 4,
num_key_value_heads: int | None = 4,
head_dim: int = 128,
vocab_size: int = 152064,
additional_vocab_size: int = 128,
@@ -220,7 +220,7 @@ class MolmoAct2TextConfig(PretrainedConfig):
max_position_embeddings: int = 4096,
rope_theta: float = 1000000.0,
rope_scaling: dict[str, Any] = None,
rope_scaling_layers: Optional[list[int]] = None,
rope_scaling_layers: list[int] | None = None,
use_qk_norm: bool = False,
qk_norm_type: str = "olmo",
layer_norm_eps: int = 1e-6,
@@ -364,11 +364,11 @@ def image_to_patches_and_grids(
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
max_crops: Optional[int]
overlap_margins: Optional[list[int]]
crop_mode: Optional[str]
patch_size: Optional[int]
pooling_size: Optional[list[int]]
max_crops: int | None
overlap_margins: list[int] | None
crop_mode: str | None
patch_size: int | None
pooling_size: list[int] | None
class MolmoAct2ImageProcessor(BaseImageProcessor):
@@ -400,10 +400,10 @@ class MolmoAct2ImageProcessor(BaseImageProcessor):
def __init__(
self,
size: Optional[dict[str, int]] = None,
size: dict[str, int] | None = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool = True,
max_crops: int = 8,
overlap_margins: list[int] = [4, 4],
@@ -431,17 +431,17 @@ class MolmoAct2ImageProcessor(BaseImageProcessor):
def preprocess(
self,
images: ImageInput,
size: Optional[dict[str, int]] = None,
resample: Optional[PILImageResampling] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
do_convert_rgb: Optional[bool] = None,
max_crops: Optional[int] = None,
overlap_margins: Optional[list[int]] = None,
crop_mode: Optional[str] = None,
patch_size: Optional[int] = None,
pooling_size: Optional[list[int]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
size: dict[str, int] | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
max_crops: int | None = None,
overlap_margins: list[int] | None = None,
crop_mode: str | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""
@@ -19,7 +19,8 @@
"""Inference utilities for MolmoAct2"""
from dataclasses import dataclass
from typing import Any, Iterable, Optional, Sequence, Tuple
from typing import Any, Optional, Tuple
from collections.abc import Iterable, Sequence
import torch
from torch.nn import functional as F
@@ -32,12 +33,12 @@ class _ActionFlowInputs:
trajectory: torch.Tensor
context: Any
modulations: Sequence[Any]
action_dim_is_pad: Optional[torch.Tensor]
action_dim_is_pad: torch.Tensor | None
@dataclass
class _ActionFlowCudaGraph:
key: Tuple[Any, ...]
key: tuple[Any, ...]
graph: torch.cuda.CUDAGraph
static_inputs: _ActionFlowInputs
output: torch.Tensor
@@ -59,7 +60,7 @@ class _DepthDecodeCudaGraphPostStage:
@dataclass
class _DepthDecodeCudaGraph:
cache_key: Tuple[Any, ...]
cache_key: tuple[Any, ...]
pre_graph: torch.cuda.CUDAGraph
token_ids: torch.Tensor
cos: torch.Tensor
@@ -73,13 +74,13 @@ class _DepthDecodeCudaGraph:
@dataclass
class _DepthDecodeCudaGraphSpec:
eligible: bool
cache_key_prefix: Tuple[Any, ...]
cache_key_prefix: tuple[Any, ...]
num_hidden_layers: int
head_dim: int
num_attention_heads: int
def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
def _cache_seq_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return 0
seq_len = past_key_values.get_seq_length()
@@ -88,7 +89,7 @@ def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
return int(seq_len)
def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
def _cache_max_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return -1
max_len = past_key_values.get_max_cache_shape()
@@ -99,7 +100,7 @@ def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
def _iter_cache_key_values(
past_key_values: Cache,
) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]:
) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]:
layers = getattr(past_key_values, "layers", None)
if layers is not None:
for layer in layers:
@@ -116,8 +117,8 @@ class _DepthDecodeStaticLayerCache:
def __init__(self, max_cache_len: int) -> None:
self.max_cache_len = int(max_cache_len)
self.cumulative_length = 0
self.keys: Optional[torch.Tensor] = None
self.values: Optional[torch.Tensor] = None
self.keys: torch.Tensor | None = None
self.values: torch.Tensor | None = None
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
bsz, n_heads = key_states.shape[:2]
@@ -138,7 +139,7 @@ class _DepthDecodeStaticLayerCache:
value_states: torch.Tensor,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
if self.keys is None:
self._allocate(key_states, value_states)
start = self.cumulative_length
@@ -185,7 +186,7 @@ class ActionCudaGraphManager:
def __init__(self, model: Any) -> None:
self.model = model
self.enabled = True
self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None
self.action_flow_graph: _ActionFlowCudaGraph | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
@@ -256,8 +257,8 @@ class DepthDecodeCudaGraphManager:
self.model = model
self.backbone = model.model
self.enabled = True
self.graph: Optional[_DepthDecodeCudaGraph] = None
self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None
self.graph: _DepthDecodeCudaGraph | None = None
self.graph_spec: _DepthDecodeCudaGraphSpec | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
@@ -320,7 +321,7 @@ class DepthDecodeCudaGraphManager:
self,
next_input_ids: torch.Tensor,
attention_bias: torch.Tensor,
) -> Tuple[Any, ...]:
) -> tuple[Any, ...]:
device = next_input_ids.device
return (
self._depth_decode_spec().cache_key_prefix,
@@ -341,7 +342,7 @@ class DepthDecodeCudaGraphManager:
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
block = self.backbone.transformer.blocks[layer_idx]
attention = block.self_attn
residual = hidden_states
@@ -378,7 +379,7 @@ class DepthDecodeCudaGraphManager:
token_ids: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
inputs_embeds = self.model._embed_base_tokens(token_ids)
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
@@ -408,7 +409,7 @@ class DepthDecodeCudaGraphManager:
attn_context: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
@@ -553,7 +554,7 @@ class DepthDecodeCudaGraphManager:
past_key_values: Cache,
attention_bias: torch.Tensor,
past_length: int,
) -> Tuple[torch.Tensor, Cache]:
) -> tuple[torch.Tensor, Cache]:
end = past_length + 1
decode_graph = self._get_depth_decode_graph(
next_input_ids,
@@ -582,8 +583,8 @@ class DepthDecodeCudaGraphManager:
def _cuda_graph_tensor_signature(
tensor: Optional[torch.Tensor],
) -> Optional[Tuple[Any, ...]]:
tensor: torch.Tensor | None,
) -> tuple[Any, ...] | None:
if tensor is None:
return None
return (
@@ -594,7 +595,7 @@ def _cuda_graph_tensor_signature(
)
def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
@@ -605,7 +606,7 @@ def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
)
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]:
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return tuple(
(
@@ -617,7 +618,7 @@ def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, .
)
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
sig(inputs.trajectory),
@@ -628,7 +629,7 @@ def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
)
def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
if tensor is None:
return None
static = torch.empty_strided(
@@ -711,7 +712,7 @@ def _apply_rotary_pos_emb(
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (_rotate_half(q) * sin)
@@ -732,7 +733,7 @@ def _capture_cuda_graph(
device: torch.device,
*,
after_warmup=None,
) -> Tuple[torch.cuda.CUDAGraph, Any]:
) -> tuple[torch.cuda.CUDAGraph, Any]:
warmup_stream = torch.cuda.Stream(device=device)
warmup_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(warmup_stream):
File diff suppressed because it is too large Load Diff
@@ -102,12 +102,12 @@ class MolmoAct2Processor(ProcessorMixin):
image_processor: MolmoAct2ImageProcessor = None,
video_processor: MolmoAct2VideoProcessor = None,
tokenizer: AutoTokenizer = None,
chat_template: Optional[str] = None,
image_use_col_tokens: Optional[bool] = True,
use_single_crop_col_tokens: Optional[bool] = None,
use_single_crop_start_token: Optional[bool] = True,
video_use_col_tokens: Optional[bool] = False,
use_frame_special_tokens: Optional[bool] = True,
chat_template: str | None = None,
image_use_col_tokens: bool | None = True,
use_single_crop_col_tokens: bool | None = None,
use_single_crop_start_token: bool | None = True,
video_use_col_tokens: bool | None = False,
use_frame_special_tokens: bool | None = True,
**kwargs,
) -> None:
super().__init__(
@@ -272,7 +272,7 @@ class MolmoAct2Processor(ProcessorMixin):
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
images: ImageInput = None,
videos: VideoInput = None,
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
@@ -24,7 +24,8 @@ import warnings
from contextlib import redirect_stdout
from io import BytesIO
from urllib.parse import urlparse
from typing import Optional, Union, Callable
from typing import Optional, Union
from collections.abc import Callable
import numpy as np
import requests
@@ -224,9 +225,9 @@ def image_to_patches_and_grids(
def get_candidate_target_fps(
video_fps: Union[int, float],
sampling_fps: Union[int, float],
max_fps: Union[int, float] = MAX_VIDEO_FPS,
video_fps: int | float,
sampling_fps: int | float,
max_fps: int | float = MAX_VIDEO_FPS,
) -> list[float]:
"""
Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
@@ -468,7 +469,7 @@ VIDEO_DECODERS = {
def load_video(
video: VideoInput,
backend: str = "decord",
sample_timestamps_fn: Optional[Callable] = None,
sample_timestamps_fn: Callable | None = None,
**kwargs,
):
"""
@@ -502,7 +503,7 @@ def load_video(
bytes_obj = buffer.getvalue()
file_obj = BytesIO(bytes_obj)
elif video.startswith("http://") or video.startswith("https://"):
file_obj = BytesIO(requests.get(video).content)
file_obj = BytesIO(requests.get(video, timeout=10).content)
elif os.path.isfile(video):
file_obj = video
else:
@@ -579,11 +580,11 @@ def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
patch_size: Optional[int]
pooling_size: Optional[list[int]]
frame_sample_mode: Optional[str]
max_fps: Optional[int]
sampling_fps: Optional[int]
patch_size: int | None
pooling_size: list[int] | None
frame_sample_mode: str | None
max_fps: int | None
sampling_fps: int | None
class MolmoAct2VideoProcessor(BaseVideoProcessor):
@@ -613,7 +614,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
def _further_process_kwargs(
self,
size: Optional[SizeDict] = None,
size: SizeDict | None = None,
**kwargs,
) -> dict:
"""
@@ -630,8 +631,8 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
metadata: VideoMetadata,
frame_sample_mode: str,
num_frames: int,
max_fps: Optional[int] = None,
sampling_fps: Optional[int] = None,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
@@ -683,10 +684,10 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
def sample_frames(
self,
metadata: VideoMetadata,
frame_sample_mode: Optional[str] = None,
num_frames: Optional[int] = None,
max_fps: Optional[int] = None,
sampling_fps: Optional[int] = None,
frame_sample_mode: str | None = None,
num_frames: int | None = None,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
@@ -761,9 +762,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
else:
raise NotImplementedError(frame_sample_mode)
def fetch_videos(
self, video_url_or_urls: Union[str, list[str], list[list[str]]], sample_timestamps_fn=None
):
def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None):
"""
Convert a single or a list of urls into the corresponding `np.array` objects.
@@ -805,10 +804,10 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
def _decode_and_sample_videos(
self,
videos: VideoInput,
video_metadata: Union[VideoMetadata, dict],
do_sample_frames: Optional[bool] = None,
sample_indices_fn: Optional[Callable] = None,
sample_timestamps_fn: Optional[Callable] = None,
video_metadata: VideoMetadata | dict,
do_sample_frames: bool | None = None,
sample_indices_fn: Callable | None = None,
sample_timestamps_fn: Callable | None = None,
):
"""
Decode input videos and sample frames if needed.
@@ -890,14 +889,14 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
def _preprocess(
self,
videos: list[np.ndarray],
size: Optional[SizeDict] = None,
resample: Optional[PILImageResampling] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
do_convert_rgb: Optional[bool] = None,
patch_size: Optional[int] = None,
pooling_size: Optional[list[int]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
size: SizeDict | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""