format molmoact2 files

This commit is contained in:
hq-fang
2026-05-21 20:54:27 +00:00
parent 733f9768b5
commit b0cdf99957
8 changed files with 340 additions and 778 deletions
@@ -118,9 +118,9 @@ class UniversalActionProcessor(ProcessorMixin):
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert (
self.time_horizon is not None and self.action_dim is not None
), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
assert self.time_horizon is not None and self.action_dim is not None, (
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
)
decoded_actions = []
for token in tokens:
@@ -128,13 +128,12 @@ class UniversalActionProcessor(ProcessorMixin):
decoded_tokens = self.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert (
decoded_dct_coeff.shape
== (
self.time_horizon,
self.action_dim,
)
), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
assert decoded_dct_coeff.shape == (
self.time_horizon,
self.action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
@@ -162,9 +161,9 @@ class UniversalActionProcessor(ProcessorMixin):
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
min_vocab_size = max_token - min_token
assert (
min_vocab_size <= vocab_size
), f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
assert min_vocab_size <= vocab_size, (
f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
)
if min_vocab_size + 100 > vocab_size:
logging.warning(
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
@@ -76,10 +76,7 @@ class MolmoAct2VitConfig(PretrainedConfig):
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(
attn_implementation=attn_implementation,
**kwargs
)
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
@@ -151,10 +148,7 @@ class MolmoAct2AdapterConfig(PretrainedConfig):
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(
attn_implementation=attn_implementation,
**kwargs
)
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.vit_layers = vit_layers
self.pooling_attention_mask = pooling_attention_mask
self.hidden_size = hidden_size
@@ -220,8 +214,8 @@ class MolmoAct2TextConfig(PretrainedConfig):
num_hidden_layers: int = 48,
intermediate_size: int = 18944,
hidden_act: str = "silu",
embedding_dropout: float=0.0,
attention_dropout: float=0.0,
embedding_dropout: float = 0.0,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
max_position_embeddings: int = 4096,
rope_theta: float = 1000000.0,
@@ -239,9 +233,7 @@ class MolmoAct2TextConfig(PretrainedConfig):
):
self.attn_implementation = attn_implementation
super().__init__(
tie_word_embeddings=tie_word_embeddings,
attn_implementation=attn_implementation,
**kwargs
tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs
)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
@@ -17,6 +17,7 @@
# ruff: noqa
"""Image processor class for MolmoAct2"""
from typing import Optional, Union
import numpy as np
import einops
@@ -72,7 +73,9 @@ def resize_image(
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
@@ -97,10 +100,10 @@ def select_tiling(h, w, patch_size, max_num_crops):
tilings = []
for i in range(1, max_num_crops + 1):
for j in range(1, max_num_crops + 1):
if i*j <= max_num_crops:
if i * j <= max_num_crops:
tilings.append((i, j))
# sort so argmin and argmax favour smaller tilings in the event of a tie
tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
@@ -110,8 +113,8 @@ def select_tiling(h, w, patch_size, max_num_crops):
# The original size can be zero in rare cases if the image is smaller than the margin
# In those cases letting the scale become infinite means the tiling is based on the
# other side, or falls back to the smallest tiling
with np.errstate(divide='ignore'):
required_scale_d = candidate_resolutions.astype(np.float32) / original_size,
with np.errstate(divide="ignore"):
required_scale_d = (candidate_resolutions.astype(np.float32) / original_size,)
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
if np.all(required_scale < 1):
# We are forced to downscale, so try to minimize the amount of downscaling
@@ -132,14 +135,16 @@ def build_resized_image(
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image, base_image_input_size, resample,
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
@@ -184,7 +189,10 @@ def build_overlapping_crops(
src = resize_image(
image,
[tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
[
tiling[0] * crop_window_size + total_margin_pixels,
tiling[1] * crop_window_size + total_margin_pixels,
],
resample,
)
src = normalize_image(src, image_mean, image_std)
@@ -198,11 +206,11 @@ def build_overlapping_crops(
for i in range(tiling[0]):
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
# which results in overlapping crop windows
y0 = i*crop_window_size
y0 = i * crop_window_size
for j in range(tiling[1]):
x0 = j*crop_window_size
crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size]
patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w)
x0 = j * crop_window_size
crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w)
patch_idx += on_crop * crop_patch_h * crop_patch_w
# Mask out idx that are in the overlap region
@@ -210,27 +218,24 @@ def build_overlapping_crops(
patch_idx[:left_margin, :] = -1
if j != 0:
patch_idx[:, :left_margin] = -1
if i != tiling[0]-1:
if i != tiling[0] - 1:
patch_idx[-right_margin:, :] = -1
if j != tiling[1]-1:
if j != tiling[1] - 1:
patch_idx[:, -right_margin:] = -1
patch_idx_arr[on_crop] = patch_idx
on_crop += 1
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
# so it is ordered left-to-right order
patch_idx_arr = np.reshape(
patch_idx_arr,
[tiling[0], tiling[1], crop_patch_h, crop_patch_w]
)
patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
# Now get the parts not in the overlap region, so it should map each patch in `src`
# to the correct patch it should come from in `crop_arr`
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
src.shape[0]//image_patch_size,
src.shape[1]//image_patch_size,
src.shape[0] // image_patch_size,
src.shape[1] // image_patch_size,
)
return crop_arr, patch_idx_arr
@@ -239,19 +244,19 @@ def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h//patch_size
w_patches = w//patch_size
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h//patch_size
w_patches = w//patch_size
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
@@ -262,10 +267,13 @@ def arange_for_pooling(
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
mode='constant',constant_values=-1)
return einops.rearrange(
idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
def image_to_patches_and_grids(
@@ -330,7 +338,7 @@ def image_to_patches_and_grids(
)
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
# Finally do the same for the global image
resized, resize_idx = build_resized_image(
@@ -345,22 +353,14 @@ def image_to_patches_and_grids(
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
resized_h, resized_w = resize_idx.shape[:2]
resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w])
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
# Global image goes first, so the order of patches in previous crops gets increased
pooling_idx = np.where(
pooling_idx >= 0,
pooling_idx + crop_patch_h*crop_patch_w,
-1
)
pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1)
pooling_idx = np.concatenate([resize_idx, pooling_idx])
image_grid = [np.array([resized_h, resized_w, h, w])]
return (
np.stack(image_grid, 0),
batch_pixels_to_patches(crop_arr, image_patch_size),
pooling_idx
)
return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx)
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
@@ -144,9 +144,7 @@ class _DepthDecodeStaticLayerCache:
start = self.cumulative_length
end = start + key_states.shape[-2]
if end > self.max_cache_len:
raise RuntimeError(
f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}."
)
raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.")
self.keys[:, :, start:end, :].copy_(key_states)
self.values[:, :, start:end, :].copy_(value_states)
self.cumulative_length = end
@@ -306,26 +304,15 @@ class DepthDecodeCudaGraphManager:
past_key_values: Cache,
attention_bias: torch.Tensor,
) -> bool:
if (
not self.enabled
or self.model.training
or self.backbone.transformer.training
):
if not self.enabled or self.model.training or self.backbone.transformer.training:
return False
if next_input_ids.device.type != "cuda":
return False
if (
next_input_ids.ndim != 2
or next_input_ids.shape[0] != 1
or next_input_ids.shape[1] != 1
):
if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1:
return False
if not isinstance(past_key_values, _DepthDecodeStaticCache):
return False
if (
not torch.is_tensor(attention_bias)
or attention_bias.device != next_input_ids.device
):
if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device:
return False
return self._depth_decode_spec().eligible
@@ -343,9 +330,7 @@ class DepthDecodeCudaGraphManager:
attention_bias.shape[-1],
)
def _select_depth_decode_rope(
self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int
) -> None:
def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None:
emb = self.backbone.transformer.rotary_emb
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
@@ -385,9 +370,7 @@ class DepthDecodeCudaGraphManager:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states, key_states = _apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)
return residual, query_states, key_states, value_states
def _depth_decode_pre0(
@@ -453,9 +436,7 @@ class DepthDecodeCudaGraphManager:
head_dim = static.head_dim
max_cache_len = int(attention_bias.shape[-1])
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
self.backbone.transformer.prepare_rope_cache(
device=device, max_seq_len=max_rope_len
)
self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len)
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
@@ -487,9 +468,7 @@ class DepthDecodeCudaGraphManager:
),
device,
)
post_graphs.append(
_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context))
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
last_stage = stages[-1]
@@ -502,11 +481,7 @@ class DepthDecodeCudaGraphManager:
),
device,
)
post_graphs.append(
_DepthDecodeCudaGraphPostStage(
graph=last_graph, attn_context=last_attn_context
)
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context))
return _DepthDecodeCudaGraph(
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
pre_graph=pre_graph,
@@ -537,9 +512,7 @@ class DepthDecodeCudaGraphManager:
self.graph = decode_graph
else:
decode_graph.token_ids.copy_(next_input_ids)
self._select_depth_decode_rope(
decode_graph.cos, decode_graph.sin, past_length=past_length
)
self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length)
return decode_graph
def _run_depth_decode_attention_core(
@@ -628,9 +601,7 @@ def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
sig(context.cross_mask),
sig(context.self_mask),
sig(context.valid_action),
None
if context.rope_cache is None
else tuple(sig(t) for t in context.rope_cache),
None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache),
)
@@ -639,10 +610,7 @@ def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, .
return tuple(
(
sig(step.conditioning),
tuple(
tuple(sig(t) for t in block_modulation)
for block_modulation in step.block_modulations
),
tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations),
tuple(sig(t) for t in step.final_modulation),
)
for step in modulations
@@ -678,10 +646,7 @@ def _clone_static_context(context: Any) -> Any:
if context.rope_cache is not None:
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
return context.__class__(
kv_contexts=tuple(
(_clone_static_tensor(k), _clone_static_tensor(v))
for k, v in context.kv_contexts
),
kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts),
cross_mask=_clone_static_tensor(context.cross_mask),
self_mask=_clone_static_tensor(context.self_mask),
valid_action=_clone_static_tensor(context.valid_action),
@@ -697,9 +662,7 @@ def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
tuple(_clone_static_tensor(t) for t in block_modulation)
for block_modulation in step.block_modulations
),
final_modulation=tuple(
_clone_static_tensor(t) for t in step.final_modulation
),
final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation),
)
for step in modulations
)
@@ -760,9 +723,7 @@ def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
File diff suppressed because it is too large Load Diff
@@ -19,6 +19,7 @@
"""
Processor class for MolmoAct2.
"""
from typing import Optional, Union
import dataclasses
@@ -50,7 +51,7 @@ IM_START_TOKEN = f"<im_start>"
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
FRAME_START_TOKEN = f"<frame_start>"
IM_END_TOKEN = f"<im_end>"
FRAME_END_TOKEN= f"<frame_end>"
FRAME_END_TOKEN = f"<frame_end>"
IM_COL_TOKEN = f"<im_col>"
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"
@@ -69,6 +70,7 @@ IMAGE_TOKENS = [
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
"""MolmoAct2 processor kwargs"""
images_kwargs: MolmoAct2ImagesKwargs
videos_kwargs: MolmoAct2VideoProcessorKwargs
_defaults = {
@@ -106,7 +108,7 @@ class MolmoAct2Processor(ProcessorMixin):
use_single_crop_start_token: Optional[bool] = True,
video_use_col_tokens: Optional[bool] = False,
use_frame_special_tokens: Optional[bool] = True,
**kwargs
**kwargs,
) -> None:
super().__init__(
image_processor,
@@ -122,10 +124,7 @@ class MolmoAct2Processor(ProcessorMixin):
self.image_placeholder_token = IMAGE_PROMPT
self.video_placeholder_token = VIDEO_PROMPT
self.image_token_ids = [
tokenizer.convert_tokens_to_ids(token)
for token in IMAGE_TOKENS
]
self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS]
def get_image_tokens(self, image_grid: np.ndarray):
resized_h, resized_w, height, width = image_grid
@@ -158,11 +157,7 @@ class MolmoAct2Processor(ProcessorMixin):
if self.use_single_crop_col_tokens is None
else self.use_single_crop_col_tokens
)
image_start_token = (
LOW_RES_IMAGE_START_TOKEN
if self.use_single_crop_start_token
else IM_START_TOKEN
)
image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN
if use_single_crop_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
@@ -190,7 +185,7 @@ class MolmoAct2Processor(ProcessorMixin):
for frame_idx, frame_time in enumerate(timestamps):
# `per-frame-compact` time mode
prev_space = " " if frame_idx > 0 else ""
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
video_string += frame_prefix
per_row = np.full(w, IMAGE_PATCH_TOKEN)
@@ -249,8 +244,8 @@ class MolmoAct2Processor(ProcessorMixin):
attention_mask = attention_mask[0]
return input_ids, attention_mask
else:
new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype)
new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype)
new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype)
new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
@@ -349,13 +344,13 @@ class MolmoAct2Processor(ProcessorMixin):
if not isinstance(text, list):
text = [text]
text = text.copy() # below lines change text in-place
text = text.copy() # below lines change text in-place
if image_grids is not None:
index = 0
for i in range(len(text)):
num_images = text[i].count(self.image_placeholder_token)
image_grids_i = image_grids[index:index+num_images]
image_grids_i = image_grids[index : index + num_images]
for image_grid in image_grids_i:
image_tokens = self.get_image_tokens(image_grid)
image_string = "".join(image_tokens)
@@ -367,8 +362,8 @@ class MolmoAct2Processor(ProcessorMixin):
for i in range(len(text)):
num_videos = text[i].count(self.video_placeholder_token)
assert num_videos in {0, 1}, "At most one video is supported for now"
video_grids_i = video_grids[index:index+num_videos]
metadata_i = video_metadata[index:index+num_videos]
video_grids_i = video_grids[index : index + num_videos]
metadata_i = video_metadata[index : index + num_videos]
for video_grid, metadata in zip(video_grids_i, metadata_i):
video_string = self.get_video_string(
video_grid,
@@ -17,6 +17,7 @@
# ruff: noqa
"""Video processor class for MolmoAct2"""
from functools import partial
import os
import warnings
@@ -100,7 +101,9 @@ def resize_image(
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
@@ -130,14 +133,16 @@ def build_resized_image(
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image, base_image_input_size, resample,
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
@@ -145,19 +150,19 @@ def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h//patch_size
w_patches = w//patch_size
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h//patch_size
w_patches = w//patch_size
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
@@ -168,10 +173,13 @@ def arange_for_pooling(
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
mode='constant',constant_values=-1)
return einops.rearrange(
idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
def image_to_patches_and_grids(
@@ -206,7 +214,7 @@ def image_to_patches_and_grids(
)
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
image_grid = [h, w]
return (
image_grid,
@@ -277,6 +285,7 @@ def read_video_decord(
"""
# Lazy import from decord
import importlib
decord = importlib.import_module("decord")
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
@@ -296,7 +305,7 @@ def read_video_decord(
target_timestamps = np.array(target_timestamps)
offset = time_stamps[0, 0]
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side='right')
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right")
ix = np.minimum(ix, len(time_stamps) - 1)
video = vr.get_batch(ix).asnumpy()
@@ -331,6 +340,7 @@ def read_video_torchcodec(
"""
# Lazy import torchcodec
import importlib
torchcodec = importlib.import_module("torchcodec")
decoder = torchcodec.decoders.VideoDecoder(
@@ -360,7 +370,7 @@ def read_video_torchcodec(
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
# out-of-bounds, to handle this we sanity check then clip them
assert all(x >= 0 for x in target_timestamps)
assert all(x < duration+1e-6 for x in target_timestamps)
assert all(x < duration + 1e-6 for x in target_timestamps)
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
# exact boundary value, we should still get the first/last frame anyway
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
@@ -369,7 +379,9 @@ def read_video_torchcodec(
timestamps = [x + time_offset for x in target_timestamps]
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
video = decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1) # Convert to THWC format
video = (
decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1)
) # Convert to THWC format
target_timestamps = np.array(target_timestamps)
metadata.frames_indices = target_timestamps * metadata.fps
@@ -397,6 +409,7 @@ def read_video_pyav(
"""
# Lazy import torchcodec
import importlib
av = importlib.import_module("av")
with av.open(video_path) as container:
@@ -413,7 +426,7 @@ def read_video_pyav(
if container_end is None or container_end < frames[-1].pts:
# Some problem with stream duration, so use the frame PTS directly
# and guess the duration of the last frame
end = frames[-1].pts * stream.time_base + 1/fps
end = frames[-1].pts * stream.time_base + 1 / fps
else:
end = container_end
duration = float(end - start)
@@ -432,7 +445,7 @@ def read_video_pyav(
target_timestamps = np.array(target_timestamps)
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side='right')
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right")
indices = np.minimum(indices, len(end_time_stamps) - 1)
video = np.stack(
@@ -480,6 +493,7 @@ def load_video(
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
# Lazy import from yt_dlp
import importlib
yt_dlp = importlib.import_module("yt_dlp")
buffer = BytesIO()
@@ -492,7 +506,9 @@ def load_video(
elif os.path.isfile(video):
file_obj = video
else:
raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
raise TypeError(
"Incorrect format used for video. Should be an url linking to an video or a local path."
)
# can also load with decord, but not cv2/torchvision
# both will fail in case of url links
@@ -551,12 +567,7 @@ def get_target_fps(
return selected_target_fps
def get_frame_times_and_chosen_fps(
selected_target_fps,
total_frames,
max_frames,
video_fps
):
def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames, video_fps):
if selected_target_fps is None:
frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
else:
@@ -656,19 +667,15 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
return times
elif frame_sample_mode == "uniform_last_frame":
if max_fps is not None:
max_duration = (num_frames-1) / max_fps # -1 to include the last frame
max_duration = (num_frames - 1) / max_fps # -1 to include the last frame
if max_duration < duration:
times = np.linspace(
0, duration, num=num_frames, endpoint=True, dtype=np.float64
)
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
else:
times = np.arange(0.0, stop=duration, step=1/max_fps)
times = np.arange(0.0, stop=duration, step=1 / max_fps)
times = np.concatenate([times, [duration]], axis=0)
assert len(times) <= num_frames
else:
times = np.linspace(
0, duration, num=num_frames, endpoint=True, dtype=np.float64
)
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
return times
else:
raise NotImplementedError(frame_sample_mode)
@@ -717,7 +724,9 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
return indices
else:
float_indices = np.arange(
0.0, stop=total_num_frames - 1, step=float(metadata.fps / max_fps),
0.0,
stop=total_num_frames - 1,
step=float(metadata.fps / max_fps),
)
if np.round(float_indices[-1]) != total_num_frames - 1:
float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
@@ -727,7 +736,10 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
return indices
elif frame_sample_mode == "uniform_last_frame":
indices = np.linspace(
0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True,
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
return indices
elif frame_sample_mode == "fps":
@@ -750,9 +762,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
raise NotImplementedError(frame_sample_mode)
def fetch_videos(
self,
video_url_or_urls: Union[str, list[str], list[list[str]]],
sample_timestamps_fn=None
self, video_url_or_urls: Union[str, list[str], list[list[str]]], sample_timestamps_fn=None
):
"""
Convert a single or a list of urls into the corresponding `np.array` objects.
@@ -760,11 +770,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
returned.
"""
if (
(not is_decord_available())
and (not is_torchcodec_available())
and (not is_av_available())
):
if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()):
raise ImportError(
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
)
@@ -785,7 +791,14 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
backend = "pyav"
if isinstance(video_url_or_urls, list):
return list(zip(*[self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) for x in video_url_or_urls]))
return list(
zip(
*[
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
for x in video_url_or_urls
]
)
)
else:
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
@@ -823,9 +836,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
)
if isinstance(videos[0], list):
raise ValueError(
"A list of images is not supported for video input!"
)
raise ValueError("A list of images is not supported for video input!")
else:
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
@@ -975,7 +986,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
pixel_values_videos = np.concatenate(batch_crops, 0)
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
data =dict(
data = dict(
pixel_values_videos=pixel_values_videos,
video_token_pooling=video_token_pooling,
video_grids=video_grids,
@@ -136,7 +136,6 @@ def _sample_beta_timesteps(
return time_offset + scale * samples
class MolmoAct2Policy(PreTrainedPolicy):
config_class = MolmoAct2Config
name = "molmoact2"