From b0cdf999579cf83bc15f90eabd552c2780b2845b Mon Sep 17 00:00:00 2001 From: hq-fang <71356829+hq-fang@users.noreply.github.com> Date: Thu, 21 May 2026 20:54:27 +0000 Subject: [PATCH] format molmoact2 files --- .../molmoact2/hf_model/action_tokenizer.py | 25 +- .../hf_model/configuration_molmoact2.py | 18 +- .../hf_model/image_processing_molmoact2.py | 84 +- .../policies/molmoact2/hf_model/inference.py | 69 +- .../molmoact2/hf_model/modeling_molmoact2.py | 781 +++++------------- .../hf_model/processing_molmoact2.py | 31 +- .../hf_model/video_processing_molmoact2.py | 109 +-- .../policies/molmoact2/modeling_molmoact2.py | 1 - 8 files changed, 340 insertions(+), 778 deletions(-) diff --git a/src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py b/src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py index fad5c1fc1..f7dacbce6 100644 --- a/src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py +++ b/src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py @@ -118,9 +118,9 @@ class UniversalActionProcessor(ProcessorMixin): self.called_time_horizon = self.time_horizon self.called_action_dim = self.action_dim - assert ( - self.time_horizon is not None and self.action_dim is not None - ), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." + assert self.time_horizon is not None and self.action_dim is not None, ( + "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." + ) decoded_actions = [] for token in tokens: @@ -128,13 +128,12 @@ class UniversalActionProcessor(ProcessorMixin): decoded_tokens = self.bpe_tokenizer.decode(token) decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim) - assert ( - decoded_dct_coeff.shape - == ( - self.time_horizon, - self.action_dim, - ) - ), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" + assert decoded_dct_coeff.shape == ( + self.time_horizon, + self.action_dim, + ), ( + f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" + ) except Exception as e: print(f"Error decoding tokens: {e}") print(f"Tokens: {token}") @@ -162,9 +161,9 @@ class UniversalActionProcessor(ProcessorMixin): min_token = int(np.around(np.concatenate(dct_tokens) * scale).min()) min_vocab_size = max_token - min_token - assert ( - min_vocab_size <= vocab_size - ), f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}" + assert min_vocab_size <= vocab_size, ( + f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}" + ) if min_vocab_size + 100 > vocab_size: logging.warning( f"Initial alphabet size {min_vocab_size} is almost as large as the vocab" diff --git a/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py index f74a36837..c66c81fe0 100644 --- a/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py @@ -76,10 +76,7 @@ class MolmoAct2VitConfig(PretrainedConfig): **kwargs, ): self.attn_implementation = attn_implementation - super().__init__( - attn_implementation=attn_implementation, - **kwargs - ) + super().__init__(attn_implementation=attn_implementation, **kwargs) self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers @@ -151,10 +148,7 @@ class MolmoAct2AdapterConfig(PretrainedConfig): **kwargs, ): self.attn_implementation = attn_implementation - super().__init__( - attn_implementation=attn_implementation, - **kwargs - ) + super().__init__(attn_implementation=attn_implementation, **kwargs) self.vit_layers = vit_layers self.pooling_attention_mask = pooling_attention_mask self.hidden_size = hidden_size @@ -220,8 +214,8 @@ class MolmoAct2TextConfig(PretrainedConfig): num_hidden_layers: int = 48, intermediate_size: int = 18944, hidden_act: str = "silu", - embedding_dropout: float=0.0, - attention_dropout: float=0.0, + embedding_dropout: float = 0.0, + attention_dropout: float = 0.0, residual_dropout: float = 0.0, max_position_embeddings: int = 4096, rope_theta: float = 1000000.0, @@ -239,9 +233,7 @@ class MolmoAct2TextConfig(PretrainedConfig): ): self.attn_implementation = attn_implementation super().__init__( - tie_word_embeddings=tie_word_embeddings, - attn_implementation=attn_implementation, - **kwargs + tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs ) self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads diff --git a/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py index b35eab286..16d2afec9 100644 --- a/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py @@ -17,6 +17,7 @@ # ruff: noqa """Image processor class for MolmoAct2""" + from typing import Optional, Union import numpy as np import einops @@ -72,7 +73,9 @@ def resize_image( )(image) resized = torch.clip(resized, 0.0, 1.0).to(dtype) else: - assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype) + assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format( + image.dtype + ) in_min = 0.0 in_max = 255.0 resized = torchvision.transforms.Resize( @@ -97,10 +100,10 @@ def select_tiling(h, w, patch_size, max_num_crops): tilings = [] for i in range(1, max_num_crops + 1): for j in range(1, max_num_crops + 1): - if i*j <= max_num_crops: + if i * j <= max_num_crops: tilings.append((i, j)) # sort so argmin and argmax favour smaller tilings in the event of a tie - tilings.sort(key=lambda x: (x[0]*x[1], x[0])) + tilings.sort(key=lambda x: (x[0] * x[1], x[0])) candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2] candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2] @@ -110,8 +113,8 @@ def select_tiling(h, w, patch_size, max_num_crops): # The original size can be zero in rare cases if the image is smaller than the margin # In those cases letting the scale become infinite means the tiling is based on the # other side, or falls back to the smallest tiling - with np.errstate(divide='ignore'): - required_scale_d = candidate_resolutions.astype(np.float32) / original_size, + with np.errstate(divide="ignore"): + required_scale_d = (candidate_resolutions.astype(np.float32) / original_size,) required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1] if np.all(required_scale < 1): # We are forced to downscale, so try to minimize the amount of downscaling @@ -132,14 +135,16 @@ def build_resized_image( image_patch_size: int, ) -> tuple[np.ndarray, np.ndarray]: resized = resize_image( - image, base_image_input_size, resample, + image, + base_image_input_size, + resample, ) resized = normalize_image(resized, image_mean, image_std) if len(resized.shape) == 3: resized = np.expand_dims(resized, 0) crop_patch_w = base_image_input_size[1] // image_patch_size crop_patch_h = base_image_input_size[0] // image_patch_size - resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w]) + resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w]) return resized, resize_idx @@ -184,7 +189,10 @@ def build_overlapping_crops( src = resize_image( image, - [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels], + [ + tiling[0] * crop_window_size + total_margin_pixels, + tiling[1] * crop_window_size + total_margin_pixels, + ], resample, ) src = normalize_image(src, image_mean, image_std) @@ -198,11 +206,11 @@ def build_overlapping_crops( for i in range(tiling[0]): # Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size` # which results in overlapping crop windows - y0 = i*crop_window_size + y0 = i * crop_window_size for j in range(tiling[1]): - x0 = j*crop_window_size - crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size] - patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w) + x0 = j * crop_window_size + crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size] + patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w) patch_idx += on_crop * crop_patch_h * crop_patch_w # Mask out idx that are in the overlap region @@ -210,27 +218,24 @@ def build_overlapping_crops( patch_idx[:left_margin, :] = -1 if j != 0: patch_idx[:, :left_margin] = -1 - if i != tiling[0]-1: + if i != tiling[0] - 1: patch_idx[-right_margin:, :] = -1 - if j != tiling[1]-1: + if j != tiling[1] - 1: patch_idx[:, -right_margin:] = -1 patch_idx_arr[on_crop] = patch_idx on_crop += 1 # `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr` # so it is ordered left-to-right order - patch_idx_arr = np.reshape( - patch_idx_arr, - [tiling[0], tiling[1], crop_patch_h, crop_patch_w] - ) + patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w]) patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3]) patch_idx_arr = np.reshape(patch_idx_arr, [-1]) # Now get the parts not in the overlap region, so it should map each patch in `src` # to the correct patch it should come from in `crop_arr` patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape( - src.shape[0]//image_patch_size, - src.shape[1]//image_patch_size, + src.shape[0] // image_patch_size, + src.shape[1] // image_patch_size, ) return crop_arr, patch_idx_arr @@ -239,19 +244,19 @@ def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray: """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]""" if len(array.shape) == 3: n_crops, h, w = array.shape - h_patches = h//patch_size - w_patches = w//patch_size + h_patches = h // patch_size + w_patches = w // patch_size array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size]) array = np.transpose(array, [0, 1, 3, 2, 4]) - array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size]) + array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size]) return array else: n_crops, h, w, c = array.shape - h_patches = h//patch_size - w_patches = w//patch_size + h_patches = h // patch_size + w_patches = w // patch_size array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c]) array = np.transpose(array, [0, 1, 3, 2, 4, 5]) - array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c]) + array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c]) return array @@ -262,10 +267,13 @@ def arange_for_pooling( ) -> np.ndarray: h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0] w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1] - idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]], - mode='constant',constant_values=-1) - return einops.rearrange( - idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w) + idx_arr = np.pad( + idx_arr, + [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]], + mode="constant", + constant_values=-1, + ) + return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w) def image_to_patches_and_grids( @@ -330,7 +338,7 @@ def image_to_patches_and_grids( ) pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w) h, w = pooling_idx.shape[:2] - pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w]) + pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w]) # Finally do the same for the global image resized, resize_idx = build_resized_image( @@ -345,22 +353,14 @@ def image_to_patches_and_grids( resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w) resized_h, resized_w = resize_idx.shape[:2] - resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w]) + resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w]) # Global image goes first, so the order of patches in previous crops gets increased - pooling_idx = np.where( - pooling_idx >= 0, - pooling_idx + crop_patch_h*crop_patch_w, - -1 - ) + pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1) pooling_idx = np.concatenate([resize_idx, pooling_idx]) image_grid = [np.array([resized_h, resized_w, h, w])] - return ( - np.stack(image_grid, 0), - batch_pixels_to_patches(crop_arr, image_patch_size), - pooling_idx - ) + return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx) class MolmoAct2ImagesKwargs(ImagesKwargs, total=False): diff --git a/src/lerobot/policies/molmoact2/hf_model/inference.py b/src/lerobot/policies/molmoact2/hf_model/inference.py index 21949e03b..1bfcb8178 100644 --- a/src/lerobot/policies/molmoact2/hf_model/inference.py +++ b/src/lerobot/policies/molmoact2/hf_model/inference.py @@ -144,9 +144,7 @@ class _DepthDecodeStaticLayerCache: start = self.cumulative_length end = start + key_states.shape[-2] if end > self.max_cache_len: - raise RuntimeError( - f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}." - ) + raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.") self.keys[:, :, start:end, :].copy_(key_states) self.values[:, :, start:end, :].copy_(value_states) self.cumulative_length = end @@ -306,26 +304,15 @@ class DepthDecodeCudaGraphManager: past_key_values: Cache, attention_bias: torch.Tensor, ) -> bool: - if ( - not self.enabled - or self.model.training - or self.backbone.transformer.training - ): + if not self.enabled or self.model.training or self.backbone.transformer.training: return False if next_input_ids.device.type != "cuda": return False - if ( - next_input_ids.ndim != 2 - or next_input_ids.shape[0] != 1 - or next_input_ids.shape[1] != 1 - ): + if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1: return False if not isinstance(past_key_values, _DepthDecodeStaticCache): return False - if ( - not torch.is_tensor(attention_bias) - or attention_bias.device != next_input_ids.device - ): + if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device: return False return self._depth_decode_spec().eligible @@ -343,9 +330,7 @@ class DepthDecodeCudaGraphManager: attention_bias.shape[-1], ) - def _select_depth_decode_rope( - self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int - ) -> None: + def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None: emb = self.backbone.transformer.rotary_emb cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :]) sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :]) @@ -385,9 +370,7 @@ class DepthDecodeCudaGraphManager: query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - query_states, key_states = _apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) + query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin) return residual, query_states, key_states, value_states def _depth_decode_pre0( @@ -453,9 +436,7 @@ class DepthDecodeCudaGraphManager: head_dim = static.head_dim max_cache_len = int(attention_bias.shape[-1]) max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len) - self.backbone.transformer.prepare_rope_cache( - device=device, max_seq_len=max_rope_len - ) + self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len) token_ids = torch.empty((1, 1), device=device, dtype=torch.long) cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype) @@ -487,9 +468,7 @@ class DepthDecodeCudaGraphManager: ), device, ) - post_graphs.append( - _DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context) - ) + post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)) stages.append(_DepthDecodeCudaGraphLayerStage(*output)) last_stage = stages[-1] @@ -502,11 +481,7 @@ class DepthDecodeCudaGraphManager: ), device, ) - post_graphs.append( - _DepthDecodeCudaGraphPostStage( - graph=last_graph, attn_context=last_attn_context - ) - ) + post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context)) return _DepthDecodeCudaGraph( cache_key=self._depth_decode_key(next_input_ids, attention_bias), pre_graph=pre_graph, @@ -537,9 +512,7 @@ class DepthDecodeCudaGraphManager: self.graph = decode_graph else: decode_graph.token_ids.copy_(next_input_ids) - self._select_depth_decode_rope( - decode_graph.cos, decode_graph.sin, past_length=past_length - ) + self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length) return decode_graph def _run_depth_decode_attention_core( @@ -628,9 +601,7 @@ def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]: sig(context.cross_mask), sig(context.self_mask), sig(context.valid_action), - None - if context.rope_cache is None - else tuple(sig(t) for t in context.rope_cache), + None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache), ) @@ -639,10 +610,7 @@ def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, . return tuple( ( sig(step.conditioning), - tuple( - tuple(sig(t) for t in block_modulation) - for block_modulation in step.block_modulations - ), + tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations), tuple(sig(t) for t in step.final_modulation), ) for step in modulations @@ -678,10 +646,7 @@ def _clone_static_context(context: Any) -> Any: if context.rope_cache is not None: rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache) return context.__class__( - kv_contexts=tuple( - (_clone_static_tensor(k), _clone_static_tensor(v)) - for k, v in context.kv_contexts - ), + kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts), cross_mask=_clone_static_tensor(context.cross_mask), self_mask=_clone_static_tensor(context.self_mask), valid_action=_clone_static_tensor(context.valid_action), @@ -697,9 +662,7 @@ def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]: tuple(_clone_static_tensor(t) for t in block_modulation) for block_modulation in step.block_modulations ), - final_modulation=tuple( - _clone_static_tensor(t) for t in step.final_modulation - ), + final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation), ) for step in modulations ) @@ -760,9 +723,7 @@ def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) diff --git a/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py index 4dde1dcd3..e0e026c4f 100644 --- a/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py @@ -105,9 +105,7 @@ _DEPTH_REASONING_PATCH_SIZE = 32 _DEPTH_REASONING_THRESHOLD = 0.996 -def _modulate( - x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor -) -> torch.Tensor: +def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) @@ -195,11 +193,7 @@ class ActionExpertRotaryEmbedding(nn.Module): ) -> Tuple[torch.Tensor, torch.Tensor]: half_dim = self.head_dim // 2 inv_freq = 1.0 / ( - self.base - ** ( - torch.arange(0, half_dim, device=device, dtype=torch.float32) - / max(half_dim, 1) - ) + self.base ** (torch.arange(0, half_dim, device=device, dtype=torch.float32) / max(half_dim, 1)) ) positions = torch.arange(seq_len, device=device, dtype=torch.float32) freqs = torch.outer(positions, inv_freq) @@ -215,9 +209,7 @@ class ActionExpertRotaryEmbedding(nn.Module): rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: if rope_cache is None: - rope_cache = self.build_cache( - seq_len=q.shape[-2], device=q.device, dtype=q.dtype - ) + rope_cache = self.build_cache(seq_len=q.shape[-2], device=q.device, dtype=q.dtype) cos, sin = rope_cache half_dim = self.head_dim // 2 @@ -247,20 +239,14 @@ class ActionExpertSelfAttention(nn.Module): self.num_heads = num_heads self.head_dim = hidden_size // num_heads self.attn_dropout = attn_dropout - self.q_norm = ( - ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None - ) - self.k_norm = ( - ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None - ) + self.q_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None + self.k_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None self.rope = ActionExpertRotaryEmbedding(self.head_dim) if use_rope else None self.qkv = nn.Linear(hidden_size, hidden_size * 3) self.out_proj = nn.Linear(hidden_size, hidden_size) self.out_drop = nn.Dropout(proj_dropout) - def _apply_qk_norm( - self, q: torch.Tensor, k: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.q_norm is None or self.k_norm is None: return q, k return self.q_norm(q), self.k_norm(k) @@ -326,12 +312,8 @@ class ActionExpertCrossAttention(nn.Module): self.num_heads = num_heads self.head_dim = hidden_size // num_heads self.attn_dropout = attn_dropout - self.q_norm = ( - ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None - ) - self.k_norm = ( - ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None - ) + self.q_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None + self.k_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None self.q_proj = nn.Linear(hidden_size, hidden_size) self.out_proj = nn.Linear(hidden_size, hidden_size) self.out_drop = nn.Dropout(proj_dropout) @@ -503,9 +485,7 @@ class ActionExpertBlock(nn.Module): kv_v=cross_kv[1], attn_mask=attn_mask, ) - x = x + gate_mlp.unsqueeze(1) * self.mlp( - _modulate(self.ff_norm(x), shift_mlp, scale_mlp) - ) + x = x + gate_mlp.unsqueeze(1) * self.mlp(_modulate(self.ff_norm(x), shift_mlp, scale_mlp)) return x @@ -579,24 +559,14 @@ class ActionExpert(nn.Module): nn.SiLU(), nn.Linear(config.hidden_size, config.hidden_size, device=device), ) - self.action_embed = nn.Linear( - config.max_action_dim, config.hidden_size, device=device - ) - self.context_k_proj = nn.Linear( - self.llm_kv_dim, config.hidden_size, bias=False, device=device - ) - self.context_v_proj = nn.Linear( - self.llm_kv_dim, config.hidden_size, bias=False, device=device - ) + self.action_embed = nn.Linear(config.max_action_dim, config.hidden_size, device=device) + self.context_k_proj = nn.Linear(self.llm_kv_dim, config.hidden_size, bias=False, device=device) + self.context_v_proj = nn.Linear(self.llm_kv_dim, config.hidden_size, bias=False, device=device) self.context_norm = ( - ActionExpertRMSNorm(config.hidden_size, eps=1e-6) - if config.context_layer_norm - else nn.Identity() + ActionExpertRMSNorm(config.hidden_size, eps=1e-6) if config.context_layer_norm else nn.Identity() ) self._modulation_cache_key: Optional[Tuple[Any, ...]] = None - self._modulation_cache_value: Optional[Sequence[ActionExpertStepModulation]] = ( - None - ) + self._modulation_cache_value: Optional[Sequence[ActionExpertStepModulation]] = None self.blocks = nn.ModuleList( [ ActionExpertBlock( @@ -613,9 +583,7 @@ class ActionExpert(nn.Module): for _ in range(config.num_layers) ] ) - self.final_layer = ActionExpertFinalLayer( - config.hidden_size, config.max_action_dim - ) + self.final_layer = ActionExpertFinalLayer(config.hidden_size, config.max_action_dim) self.reset_parameters() def reset_parameters(self) -> None: @@ -653,9 +621,7 @@ class ActionExpert(nn.Module): _init_linear(self.final_layer.linear, zero=True) def _reshape_hidden_to_heads(self, x: torch.Tensor) -> torch.Tensor: - return x.view( - x.shape[0], x.shape[1], self.config.num_heads, self.action_head_dim - ) + return x.view(x.shape[0], x.shape[1], self.config.num_heads, self.action_head_dim) def _time_conditioning(self, timesteps: torch.Tensor) -> torch.Tensor: conditioning = self.time_embed[0](timesteps) @@ -713,13 +679,8 @@ class ActionExpert(nn.Module): key_mask = (~valid)[:, None, None, :].to(dtype=dtype) mask = key_mask * torch.finfo(dtype).min if self.config.causal_attn: - causal = torch.ones(seq_len, seq_len, device=device, dtype=torch.bool).triu( - diagonal=1 - ) - causal = ( - causal.unsqueeze(0).unsqueeze(0).to(dtype=dtype) - * torch.finfo(dtype).min - ) + causal = torch.ones(seq_len, seq_len, device=device, dtype=torch.bool).triu(diagonal=1) + causal = causal.unsqueeze(0).unsqueeze(0).to(dtype=dtype) * torch.finfo(dtype).min mask = causal if mask is None else mask + causal return mask @@ -742,9 +703,7 @@ class ActionExpert(nn.Module): ) valid_action = None if action_attention_mask is not None: - valid_action = action_attention_mask.to( - device=device, dtype=dtype - ).unsqueeze(-1) + valid_action = action_attention_mask.to(device=device, dtype=dtype).unsqueeze(-1) rope_cache = None if len(self.blocks) > 0 and self.blocks[0].self_attn.rope is not None: rope_cache = self.blocks[0].self_attn.rope.build_cache( @@ -758,9 +717,7 @@ class ActionExpert(nn.Module): batch_size, dtype, ) - self_mask = self._build_self_attention_mask( - action_attention_mask, seq_len, device, dtype - ) + self_mask = self._build_self_attention_mask(action_attention_mask, seq_len, device, dtype) return ActionExpertContext( kv_contexts=kv_contexts, cross_mask=cross_mask, @@ -778,12 +735,8 @@ class ActionExpert(nn.Module): conditioning = self._time_conditioning(step_t) block_modulations = [] for block in self.blocks: - block_modulations.append( - tuple(block.modulation(conditioning).chunk(9, dim=1)) - ) - final_modulation = tuple( - self.final_layer.modulation(conditioning).chunk(2, dim=1) - ) + block_modulations.append(tuple(block.modulation(conditioning).chunk(9, dim=1))) + final_modulation = tuple(self.final_layer.modulation(conditioning).chunk(2, dim=1)) cache.append( ActionExpertStepModulation( conditioning=conditioning, @@ -801,10 +754,7 @@ class ActionExpert(nn.Module): ) -> Sequence[ActionExpertStepModulation]: if self.training or cache_key is None: return self.prepare_modulation_cache(timesteps) - if ( - self._modulation_cache_key == cache_key - and self._modulation_cache_value is not None - ): + if self._modulation_cache_key == cache_key and self._modulation_cache_value is not None: return self._modulation_cache_value cached = self.prepare_modulation_cache(timesteps) self._modulation_cache_key = cache_key @@ -826,9 +776,7 @@ class ActionExpert(nn.Module): ) if modulation is None: conditioning = self._time_conditioning(timesteps) - block_modulations: Sequence[Optional[Tuple[torch.Tensor, ...]]] = [ - None - ] * len(self.blocks) + block_modulations: Sequence[Optional[Tuple[torch.Tensor, ...]]] = [None] * len(self.blocks) final_modulation = None else: conditioning = modulation.conditioning @@ -960,9 +908,7 @@ class _FeatureNormalizer: self.zero_mask = zero_mask @classmethod - def from_stats( - cls, stats: Optional[Mapping[str, Any]], mode: str - ) -> Optional["_FeatureNormalizer"]: + def from_stats(cls, stats: Optional[Mapping[str, Any]], mode: str) -> Optional["_FeatureNormalizer"]: if stats is None: return None raw_mask = stats.get("mask") if isinstance(stats, Mapping) else None @@ -1006,15 +952,11 @@ class _FeatureNormalizer: q_low = _to_array(stats.get(low_key)) q_high = _to_array(stats.get(high_key)) if q_low is None or q_high is None: - raise ValueError( - f"norm_mode={mode!r} requires {low_key} and {high_key} stats." - ) + raise ValueError(f"norm_mode={mode!r} requires {low_key} and {high_key} stats.") min_val = _to_array(stats.get("min")) max_val = _to_array(stats.get("max")) fallback = min_val if min_val is not None else q_low - zero_mask = ( - None if min_val is None or max_val is None else (min_val == max_val) - ) + zero_mask = None if min_val is None or max_val is None else (min_val == max_val) return cls( mode=mode, min_val=min_val, @@ -1036,17 +978,9 @@ class _FeatureNormalizer: elif self.mode == "mean_std": normed = (arr - self.mean) / np.maximum(self.std, eps) elif self.mode == "min_max": - normed = ( - 2.0 - * (arr - self.min_val) - / np.maximum(self.max_val - self.min_val, eps) - - 1.0 - ) + normed = 2.0 * (arr - self.min_val) / np.maximum(self.max_val - self.min_val, eps) - 1.0 elif self.mode in {"q01_q99", "q10_q90"}: - normed = ( - 2.0 * (arr - self.q_low) / np.maximum(self.q_high - self.q_low, eps) - - 1.0 - ) + normed = 2.0 * (arr - self.q_low) / np.maximum(self.q_high - self.q_low, eps) - 1.0 else: normed = arr if self.mode in {"min_max", "q01_q99", "q10_q90"}: @@ -1109,9 +1043,7 @@ class _RobotStats: raise ValueError("MolmoAct2 `predict_action` requires `norm_tag`.") if tag not in self.metadata_by_tag: allowed = ", ".join(sorted(self.metadata_by_tag)) - raise ValueError( - f"Unknown MolmoAct2 normalization tag {tag!r}. Allowed tags: {allowed}." - ) + raise ValueError(f"Unknown MolmoAct2 normalization tag {tag!r}. Allowed tags: {allowed}.") return tag def get_metadata(self, norm_tag: Optional[str]) -> Dict[str, Any]: @@ -1149,9 +1081,7 @@ class _RobotStats: return None value = int(value) if value < 1: - raise ValueError( - f"Robot metadata for norm_tag={norm_tag!r} must define {key} >= 1." - ) + raise ValueError(f"Robot metadata for norm_tag={norm_tag!r} must define {key} >= 1.") return value @@ -1206,12 +1136,8 @@ def _compute_depth_update_mask( f"enable_adaptive_depth=True requires a square depth grid, got num_depth_codes={int(num_depth_codes)}." ) target_size = grid_side * _DEPTH_REASONING_PATCH_SIZE - current_resized = _resize_depth_reasoning_image(current_image, target_size).astype( - np.float32 - ) - previous_resized = _resize_depth_reasoning_image( - previous_image, target_size - ).astype(np.float32) + current_resized = _resize_depth_reasoning_image(current_image, target_size).astype(np.float32) + previous_resized = _resize_depth_reasoning_image(previous_image, target_size).astype(np.float32) current_patches = ( current_resized.reshape( grid_side, @@ -1239,9 +1165,7 @@ def _compute_depth_update_mask( norm_previous = np.linalg.norm(previous_patches, axis=-1) denom = norm_current * norm_previous similarity = np.where(denom < 1e-8, 1.0, dot / (denom + 1e-12)) - return np.asarray(similarity < _DEPTH_REASONING_THRESHOLD, dtype=np.bool_).reshape( - -1 - ) + return np.asarray(similarity < _DEPTH_REASONING_THRESHOLD, dtype=np.bool_).reshape(-1) def _build_depth_update_spans( @@ -1266,9 +1190,7 @@ def _build_depth_update_spans( def _wrap_setup_text(setup_type: str, add_setup_tokens: bool = False) -> str: setup_type = str(setup_type or "") - if setup_type.startswith(SETUP_START_TOKEN) and setup_type.endswith( - SETUP_END_TOKEN - ): + if setup_type.startswith(SETUP_START_TOKEN) and setup_type.endswith(SETUP_END_TOKEN): return setup_type if not setup_type or not add_setup_tokens: return setup_type @@ -1277,18 +1199,14 @@ def _wrap_setup_text(setup_type: str, add_setup_tokens: bool = False) -> str: def _wrap_control_text(control_mode: str, add_control_tokens: bool = False) -> str: control_mode = str(control_mode or "") - if control_mode.startswith(CONTROL_START_TOKEN) and control_mode.endswith( - CONTROL_END_TOKEN - ): + if control_mode.startswith(CONTROL_START_TOKEN) and control_mode.endswith(CONTROL_END_TOKEN): return control_mode if not control_mode or not add_control_tokens: return control_mode return f"{CONTROL_START_TOKEN}{control_mode}{CONTROL_END_TOKEN}" -def _discretize_normalized_state( - state: np.ndarray, num_state_tokens: int -) -> np.ndarray: +def _discretize_normalized_state(state: np.ndarray, num_state_tokens: int) -> np.ndarray: arr = np.asarray(state, dtype=np.float32) arr = np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=-1.0) arr = np.clip(arr, -1.0, 1.0) @@ -1296,9 +1214,7 @@ def _discretize_normalized_state( return np.clip(np.rint(scaled).astype(np.int64), 0, int(num_state_tokens) - 1) -def _build_discrete_state_string( - state: Optional[np.ndarray], num_state_tokens: int -) -> str: +def _build_discrete_state_string(state: Optional[np.ndarray], num_state_tokens: int) -> str: if state is None: return "" token_ids = _discretize_normalized_state(state, num_state_tokens).reshape(-1) @@ -1318,9 +1234,7 @@ def _normalize_question_text(text: str) -> str: normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() normalized = normalized.rstrip(_QUESTION_TRAILING_CLOSERS).rstrip() normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() - sentence_chunks = [ - chunk.strip() for chunk in re.split(r"[.!?]+", normalized) if chunk.strip() - ] + sentence_chunks = [chunk.strip() for chunk in re.split(r"[.!?]+", normalized) if chunk.strip()] if len(sentence_chunks) > 1: normalized = "; ".join(sentence_chunks) normalized = normalized.lower() @@ -1339,13 +1253,9 @@ def _build_robot_text( num_images: int, ) -> str: setup_text = _wrap_setup_text(setup_type, add_setup_tokens=add_setup_tokens) - control_text = _wrap_control_text( - control_mode, add_control_tokens=add_control_tokens - ) + control_text = _wrap_control_text(control_mode, add_control_tokens=add_control_tokens) state_clause = ( - f" The current state of the robot is {discrete_state_string}." - if discrete_state_string - else "" + f" The current state of the robot is {discrete_state_string}." if discrete_state_string else "" ) if style == "robot_depth_action": prompt = ( @@ -1376,9 +1286,7 @@ def _flatten_generated_token_ids(token_ids: torch.Tensor) -> List[int]: return [int(x) for x in token_ids[0].detach().cpu().tolist()] if token_ids.ndim == 1: return [int(x) for x in token_ids.detach().cpu().tolist()] - raise ValueError( - f"Unexpected generated token tensor shape {tuple(token_ids.shape)}" - ) + raise ValueError(f"Unexpected generated token tensor shape {tuple(token_ids.shape)}") def _extract_discrete_token_bins( @@ -1549,9 +1457,7 @@ class ViTMultiHeadDotProductAttention(nn.Module): ] def _split_heads(self, hidden_states, num_heads) -> torch.Tensor: - return hidden_states.reshape( - hidden_states.shape[:2] + (num_heads, self.head_dim) - ) + return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) def _merge_heads(self, hidden_states) -> torch.Tensor: return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) @@ -1577,12 +1483,8 @@ class ViTMultiHeadDotProductAttention(nn.Module): xv = self._split_heads(xv, self.num_key_value_heads) if self.num_heads != self.num_key_value_heads: - xk = xk.repeat_interleave( - self.num_key_value_groups, dim=2, output_size=self.num_heads - ) - xv = xv.repeat_interleave( - self.num_key_value_groups, dim=2, output_size=self.num_heads - ) + xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads) + xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads) og_dtype = xq.dtype @@ -1593,16 +1495,10 @@ class ViTMultiHeadDotProductAttention(nn.Module): dropout_p = 0.0 if not self.training else self.attention_dropout if self.attn_implementation == "eager": - attn_weights = torch.einsum( - "...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - xq.dtype - ) + attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype) attn_weights = F.dropout(attn_weights, p=dropout_p, training=self.training) - attn_output = torch.einsum( - "...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv - ) + attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv) elif self.attn_implementation == "sdpa": if self.float32_attention: @@ -1651,9 +1547,7 @@ class ViTMultiHeadDotProductAttention(nn.Module): implementation=self.attn_implementation, ) else: - raise ValueError( - f"Attention implementation {self.attn_implementation} not supported" - ) + raise ValueError(f"Attention implementation {self.attn_implementation} not supported") attn_output = attn_output.to(og_dtype) attn_output = self._merge_heads(attn_output) @@ -1664,9 +1558,7 @@ class ViTMultiHeadDotProductAttention(nn.Module): class MolmoAct2VisionBlock(nn.Module): - def __init__( - self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None - ): + def __init__(self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None): super().__init__() self.attention = ViTMultiHeadDotProductAttention( hidden_size=config.hidden_size, @@ -1685,12 +1577,8 @@ class MolmoAct2VisionBlock(nn.Module): config.hidden_act, device=device, ) - self.attention_norm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps, device=device - ) - self.ffn_norm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps, device=device - ) + self.attention_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device) + self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attention(self.attention_norm(x)) @@ -1699,16 +1587,11 @@ class MolmoAct2VisionBlock(nn.Module): class MolmoAct2VisionBlockCollection(nn.Module): - def __init__( - self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None - ): + def __init__(self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None): super().__init__() self.conifg = config self.resblocks = nn.ModuleList( - [ - MolmoAct2VisionBlock(config, device) - for _ in range(config.num_hidden_layers) - ] + [MolmoAct2VisionBlock(config, device) for _ in range(config.num_hidden_layers)] ) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: @@ -1720,9 +1603,7 @@ class MolmoAct2VisionBlockCollection(nn.Module): class MolmoAct2VisionTransformer(nn.Module): - def __init__( - self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None - ): + def __init__(self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None): super().__init__() self.config = config @@ -1811,9 +1692,7 @@ class ImageProjectorMLP(nn.Module): class MolmoAct2VisionBackbone(nn.Module): - def __init__( - self, vit_config: MolmoAct2VitConfig, adapter_config: MolmoAct2AdapterConfig - ): + def __init__(self, vit_config: MolmoAct2VitConfig, adapter_config: MolmoAct2AdapterConfig): super().__init__() self.vit_config = vit_config self.adapter_config = adapter_config @@ -1879,7 +1758,9 @@ class MolmoAct2VisionBackbone(nn.Module): missing = needed_layers - set(selected_features) if missing: - raise RuntimeError(f"MolmoAct2 vision backbone did not produce requested layers: {sorted(missing)}.") + raise RuntimeError( + f"MolmoAct2 vision backbone did not produce requested layers: {sorted(missing)}." + ) image_features = torch.cat([selected_features[int(layer)] for layer in self.vit_layers], dim=-1) @@ -1934,31 +1815,23 @@ class MolmoAct2VisionBackbone(nn.Module): ) # Now [batch, num_high_res_features, pool_dim, dim] - to_pool = image_features.reshape(batch_size, -1, dim)[ - batch_idx, torch.clip(pooled_patches_idx, 0) - ] + to_pool = image_features.reshape(batch_size, -1, dim)[batch_idx, torch.clip(pooled_patches_idx, 0)] to_pool = to_pool * valid.to(self.dtype)[:, :, :, None] to_pool = to_pool.reshape([-1, pooled_patches_idx.shape[-1], dim]) if self.adapter_config.pooling_attention_mask: attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]]) denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1) denom = torch.where(denom == 0, 1, denom) - query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to( - to_pool.dtype - ) + query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to(to_pool.dtype) else: attn_mask = None query = to_pool.mean(-2, keepdim=True) pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask) - pooled_features = pooled_features.reshape( - [batch_size, -1, pooled_features.shape[-1]] - ) + pooled_features = pooled_features.reshape([batch_size, -1, pooled_features.shape[-1]]) # MLP layer to map the feature. pooled_features = self.image_projector(pooled_features) - return pooled_features.view(-1, pooled_features.shape[-1])[ - valid_token.flatten() - ] + return pooled_features.view(-1, pooled_features.shape[-1])[valid_token.flatten()] # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -2013,9 +1886,7 @@ class MolmoAct2RotaryEmbedding(nn.Module): self.rope_type = rope_type elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): # BC: "rope_type" was originally "type" - self.rope_type = config.rope_scaling.get( - "rope_type", config.rope_scaling.get("type") - ) + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings @@ -2039,16 +1910,11 @@ class MolmoAct2RotaryEmbedding(nn.Module): ) -> tuple[torch.Tensor, float]: inv_freq = 1.0 / ( config.rope_theta - ** ( - torch.arange(0, config.head_dim, 2, dtype=torch.float32, device=device) - / config.head_dim - ) + ** (torch.arange(0, config.head_dim, 2, dtype=torch.float32, device=device) / config.head_dim) ) return inv_freq, 1.0 - def _target_cache_seq_len( - self, x: torch.Tensor, position_ids: Optional[torch.Tensor] - ) -> int: + def _target_cache_seq_len(self, x: torch.Tensor, position_ids: Optional[torch.Tensor]) -> int: if self.config.max_position_embeddings: return int(self.config.max_position_embeddings) if position_ids is not None: @@ -2079,9 +1945,7 @@ class MolmoAct2RotaryEmbedding(nn.Module): needs_refresh = ( not bool(torch.isfinite(inv_freq_cpu).all().item()) or bool((inv_freq_cpu <= 0).any().item()) - or not bool( - torch.isclose(inv_freq_cpu[0].cpu(), torch.tensor(1.0)).item() - ) + or not bool(torch.isclose(inv_freq_cpu[0].cpu(), torch.tensor(1.0)).item()) ) if needs_refresh: inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) @@ -2094,9 +1958,7 @@ class MolmoAct2RotaryEmbedding(nn.Module): device_type = device.type if device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): seq = torch.arange(seq_len, device=device, dtype=torch.float) - freqs = torch.einsum( - "i,j->ij", seq, self.inv_freq.to(device=device, dtype=torch.float) - ) + freqs = torch.einsum("i,j->ij", seq, self.inv_freq.to(device=device, dtype=torch.float)) emb = torch.cat((freqs, freqs), dim=-1) self._pos_sin_cache = emb.sin()[None, None, :, :] * self.attention_scaling self._pos_cos_cache = emb.cos()[None, None, :, :] * self.attention_scaling @@ -2113,9 +1975,7 @@ class MolmoAct2RotaryEmbedding(nn.Module): device = torch.device(device) seq_len = int(max_seq_len or self.config.max_position_embeddings or 0) if seq_len <= 0: - raise ValueError( - "RoPE cache preparation requires a positive max sequence length." - ) + raise ValueError("RoPE cache preparation requires a positive max sequence length.") if self._rope_cache_ready(device, seq_len): return self._refresh_inv_freq_if_needed(device) @@ -2133,19 +1993,13 @@ class MolmoAct2RotaryEmbedding(nn.Module): sin = pos_sin[0, 0, : x.shape[-2], :] cos = pos_cos[0, 0, : x.shape[-2], :] else: - sin = pos_sin[0, 0][position_ids].view( - position_ids.shape + (pos_sin.shape[-1],) - ) - cos = pos_cos[0, 0][position_ids].view( - position_ids.shape + (pos_cos.shape[-1],) - ) + sin = pos_sin[0, 0][position_ids].view(position_ids.shape + (pos_sin.shape[-1],)) + cos = pos_cos[0, 0][position_ids].view(position_ids.shape + (pos_cos.shape[-1],)) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward( - self, x, position_ids: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, x, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: seq_len = self._target_cache_seq_len(x, position_ids) if not self._rope_cache_ready(x.device, seq_len): self._refresh_inv_freq_if_needed(x.device) @@ -2187,9 +2041,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -2211,12 +2063,8 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query.dtype - ) - attn_weights = nn.functional.dropout( - attn_weights, p=dropout, training=module.training - ) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -2232,9 +2080,7 @@ class MolmoAct2Attention(nn.Module): self.layer_idx = layer_idx self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = ( - config.num_attention_heads // config.num_key_value_heads - ) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.head_dim = config.head_dim self.scaling = self.head_dim**-0.5 self.is_causal = True @@ -2295,21 +2141,13 @@ class MolmoAct2Attention(nn.Module): value_states = value_states.view(hidden_shape) # Optionally apply layer norm to keys and queries. - if ( - self.q_norm is not None - and self.k_norm is not None - and self.qk_norm_type != "qwen3" - ): + if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type != "qwen3": query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) query_states = query_states.view(hidden_shape) key_states = key_states.view(hidden_shape) - if ( - self.q_norm is not None - and self.k_norm is not None - and self.qk_norm_type == "qwen3" - ): + if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type == "qwen3": query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) query_states = query_states.transpose(1, 2) @@ -2317,9 +2155,7 @@ class MolmoAct2Attention(nn.Module): value_states = value_states.transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -2350,9 +2186,7 @@ class MolmoAct2Attention(nn.Module): else: attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[ - self.config._attn_implementation - ] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -2381,9 +2215,7 @@ class LanguageModelMLP(nn.Module): device: Union[str, torch.device] = None, ): super().__init__() - self.ff_proj = nn.Linear( - input_dim, intermediate_size * 2, bias=False, device=device - ) + self.ff_proj = nn.Linear(input_dim, intermediate_size * 2, bias=False, device=device) self.ff_out = nn.Linear(intermediate_size, input_dim, bias=False, device=device) self.act = ACT2FN[hidden_act] @@ -2406,9 +2238,7 @@ class MolmoAct2DecoderLayer(GradientCheckpointingLayer): self.config = config self.self_attn = MolmoAct2Attention(config, layer_idx) - self.attn_norm = MolmoAct2RMSNorm( - config.hidden_size, eps=config.layer_norm_eps, device=device - ) + self.attn_norm = MolmoAct2RMSNorm(config.hidden_size, eps=config.layer_norm_eps, device=device) self.dropout = nn.Dropout(config.residual_dropout) self.mlp = LanguageModelMLP( config.hidden_size, @@ -2416,9 +2246,7 @@ class MolmoAct2DecoderLayer(GradientCheckpointingLayer): config.hidden_act, device=device, ) - self.ff_norm = MolmoAct2RMSNorm( - config.hidden_size, eps=config.layer_norm_eps, device=device - ) + self.ff_norm = MolmoAct2RMSNorm(config.hidden_size, eps=config.layer_norm_eps, device=device) def forward( self, @@ -2431,9 +2259,7 @@ class MolmoAct2DecoderLayer(GradientCheckpointingLayer): use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple[ - torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] - ]: + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) residual = hidden_states @@ -2486,9 +2312,7 @@ class MolmoAct2PostNormDecoderLayer(MolmoAct2DecoderLayer): use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, - ) -> tuple[ - torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] - ]: + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) residual = hidden_states @@ -2606,16 +2430,9 @@ class MolmoAct2TextModel(MolmoAct2PreTrainedModel): else: self.wte = nn.Embedding(config.vocab_size, config.hidden_size) self.emb_drop = nn.Dropout(config.embedding_dropout) - decoder_layer = ( - MolmoAct2PostNormDecoderLayer - if config.norm_after - else MolmoAct2DecoderLayer - ) + decoder_layer = MolmoAct2PostNormDecoderLayer if config.norm_after else MolmoAct2DecoderLayer self.blocks = nn.ModuleList( - [ - decoder_layer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] + [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.ln_f = MolmoAct2RMSNorm(config.hidden_size, eps=config.layer_norm_eps) if config.rope_scaling_layers is not None: @@ -2666,14 +2483,10 @@ class MolmoAct2TextModel(MolmoAct2PreTrainedModel): **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) @@ -2683,9 +2496,7 @@ class MolmoAct2TextModel(MolmoAct2PreTrainedModel): use_cache = False if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( @@ -2702,9 +2513,7 @@ class MolmoAct2TextModel(MolmoAct2PreTrainedModel): past_key_values = DynamicCache(config=self.config) if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], @@ -2747,9 +2556,7 @@ class MolmoAct2TextModel(MolmoAct2PreTrainedModel): all_self_attns = () if output_attentions else None collected_kv_states = [] if collect_layer_kv_states else None - for layer_idx, decoder_block in enumerate( - self.blocks[: self.config.num_hidden_layers] - ): + for layer_idx, decoder_block in enumerate(self.blocks[: self.config.num_hidden_layers]): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -2816,13 +2623,9 @@ def token_type_ids_mask_function( # Since vmap doesn't support `if statement` we workaround it with `torch.where` safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0) token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx] - token_type_ids_at_kv_idx = torch.where( - kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0 - ) + token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0) - is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & ( - token_type_ids_at_kv_idx == 1 - ) + is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1) # This is bidirectional attention whenever we are dealing with image tokens return is_image_block & is_image_block @@ -2842,12 +2645,8 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): self.transformer: MolmoAct2TextModel = MolmoAct2TextModel(config.text_config) self.vision_backbone: Optional[MolmoAct2VisionBackbone] = None if config.vit_config is not None and config.adapter_config is not None: - self.vision_backbone = MolmoAct2VisionBackbone( - config.vit_config, config.adapter_config - ) - llm_kv_dim = ( - config.text_config.num_key_value_heads * config.text_config.head_dim - ) + self.vision_backbone = MolmoAct2VisionBackbone(config.vit_config, config.adapter_config) + llm_kv_dim = config.text_config.num_key_value_heads * config.text_config.head_dim if config.add_action_expert: self.action_expert = ActionExpert( config.action_expert_config, @@ -2860,8 +2659,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): if config.add_action_expert and config.action_expert_depth_gate: if config.action_expert_depth_gate_per_layer: self.action_expert_depth_gate = nn.ModuleList( - nn.Linear(llm_kv_dim, 1) - for _ in range(config.action_expert_config.num_layers) + nn.Linear(llm_kv_dim, 1) for _ in range(config.action_expert_config.num_layers) ) else: self.action_expert_depth_gate = nn.Linear(llm_kv_dim, 1) @@ -2900,9 +2698,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): ) for gate in gates: nn.init.zeros_(gate.weight) - nn.init.constant_( - gate.bias, float(self.config.action_expert_depth_gate_init_bias) - ) + nn.init.constant_(gate.bias, float(self.config.action_expert_depth_gate_init_bias)) def _resolve_depth_gate_token_ids(self) -> Tuple[int, ...]: if not self.config.action_expert_depth_gate: @@ -2915,26 +2711,19 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): ): if token_id is not None: token_ids.append(int(token_id)) - if ( - self.config.depth_token_start_id is not None - and int(self.config.num_depth_tokens or 0) > 0 - ): + if self.config.depth_token_start_id is not None and int(self.config.num_depth_tokens or 0) > 0: start = int(self.config.depth_token_start_id) token_ids.extend(range(start, start + int(self.config.num_depth_tokens))) return tuple(dict.fromkeys(token_ids)) def _require_action_expert(self) -> ActionExpert: if self.action_expert is None: - raise RuntimeError( - "This MolmoAct2 checkpoint does not include an action expert." - ) + raise RuntimeError("This MolmoAct2 checkpoint does not include an action expert.") return self.action_expert def _cache_to_sequence(self, cache: torch.Tensor) -> torch.Tensor: if cache.dim() != 4: - raise ValueError( - f"Expected KV cache tensor with 4 dims, got shape {tuple(cache.shape)}" - ) + raise ValueError(f"Expected KV cache tensor with 4 dims, got shape {tuple(cache.shape)}") head_candidates = { self.config.text_config.num_key_value_heads, self.config.text_config.num_attention_heads, @@ -2951,13 +2740,9 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): bsz, seq_len, n_heads, head_dim = cache.shape return cache.reshape(bsz, seq_len, n_heads * head_dim) - def _extract_kv_states( - self, past_key_values: Cache - ) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]: + def _extract_kv_states(self, past_key_values: Cache) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]: if past_key_values is None: - raise RuntimeError( - "Action generation requires past_key_values from the VLM forward pass." - ) + raise RuntimeError("Action generation requires past_key_values from the VLM forward pass.") seq_len = _cache_seq_len_int(past_key_values) kv_states = [] for key, value in _iter_cache_key_values(past_key_values): @@ -2966,9 +2751,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): if key.shape[-2] > seq_len: key = key[..., :seq_len, :] value = value[..., :seq_len, :] - kv_states.append( - (self._cache_to_sequence(key), self._cache_to_sequence(value)) - ) + kv_states.append((self._cache_to_sequence(key), self._cache_to_sequence(value))) if len(kv_states) != self.config.action_expert_config.num_layers: raise RuntimeError( f"Expected {self.config.action_expert_config.num_layers} KV layers, got {len(kv_states)}." @@ -2984,9 +2767,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): ) -> None: if start_id is None or end_id is None: return - start_positions = ( - (row_ids == start_id).nonzero(as_tuple=False).flatten().tolist() - ) + start_positions = (row_ids == start_id).nonzero(as_tuple=False).flatten().tolist() if not start_positions: return end_positions = (row_ids == end_id).nonzero(as_tuple=False).flatten().tolist() @@ -3031,11 +2812,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): input_ids: Optional[torch.Tensor], encoder_attention_mask: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: - if ( - not self.config.action_expert_depth_gate - or input_ids is None - or not self._depth_gate_token_ids - ): + if not self.config.action_expert_depth_gate or input_ids is None or not self._depth_gate_token_ids: return None depth_token_ids = torch.as_tensor( self._depth_gate_token_ids, @@ -3044,9 +2821,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): ) depth_mask = (input_ids.unsqueeze(-1) == depth_token_ids).any(dim=-1) if encoder_attention_mask is not None: - depth_mask = depth_mask & encoder_attention_mask.to( - device=input_ids.device, dtype=torch.bool - ) + depth_mask = depth_mask & encoder_attention_mask.to(device=input_ids.device, dtype=torch.bool) return depth_mask @staticmethod @@ -3060,17 +2835,11 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): if source.ndim == 4: source = source.reshape(source.shape[0], source.shape[1], -1) if source.ndim != 3: - raise ValueError( - f"Depth gate expected a 3D sequence tensor, got {tuple(source.shape)}." - ) + raise ValueError(f"Depth gate expected a 3D sequence tensor, got {tuple(source.shape)}.") if encoder_attention_mask is not None: - valid_mask = encoder_attention_mask.to( - device=source.device, dtype=torch.bool - ) + valid_mask = encoder_attention_mask.to(device=source.device, dtype=torch.bool) else: - valid_mask = torch.ones( - depth_mask.shape, device=source.device, dtype=torch.bool - ) + valid_mask = torch.ones(depth_mask.shape, device=source.device, dtype=torch.bool) depth_mask = depth_mask.to(device=source.device, dtype=torch.bool) pool_mask = valid_mask & ~depth_mask has_pool = pool_mask.any(dim=-1, keepdim=True) @@ -3086,9 +2855,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): input_ids: Optional[torch.Tensor], encoder_attention_mask: Optional[torch.Tensor], layer_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]], - ) -> Tuple[ - Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], Optional[torch.Tensor] - ]: + ) -> Tuple[Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], Optional[torch.Tensor]]: gate_head = self.action_expert_depth_gate if gate_head is None: return None, None @@ -3129,9 +2896,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): if isinstance(gate, torch.Tensor): return gate if len(gate) != num_layers: - raise ValueError( - f"Depth gate layer count mismatch: gates={len(gate)}, layers={num_layers}." - ) + raise ValueError(f"Depth gate layer count mismatch: gates={len(gate)}, layers={num_layers}.") return gate[layer_idx] def _apply_depth_gate_to_layer_kv_states( @@ -3144,9 +2909,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): return layer_kv_states gated_kv = [] for layer_idx, (key, value) in enumerate(layer_kv_states): - layer_gate = self._depth_gate_for_layer( - gate, layer_idx, num_layers=len(layer_kv_states) - ) + layer_gate = self._depth_gate_for_layer(gate, layer_idx, num_layers=len(layer_kv_states)) mask = depth_mask.to(device=key.device, dtype=torch.bool) view_shape = [mask.shape[0], mask.shape[1]] + [1] * (key.ndim - 2) scale = torch.ones(view_shape, device=key.device, dtype=key.dtype) @@ -3197,9 +2960,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): return tensor return tensor.masked_fill(~valid_mask, 0) - def _run_action_flow_loop( - self, inputs: _ActionFlowInputs, steps: int - ) -> torch.Tensor: + def _run_action_flow_loop(self, inputs: _ActionFlowInputs, steps: int) -> torch.Tensor: action_expert = self._require_action_expert() dt = 1.0 / steps trajectory = inputs.trajectory @@ -3274,13 +3035,9 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): use_cache=True, ) encoder_kv_states = self._extract_kv_states(outputs.past_key_values) - encoder_attention_mask = self._get_encoder_attention_mask( - input_ids, attention_mask - ) + encoder_attention_mask = self._get_encoder_attention_mask(input_ids, attention_mask) elif encoder_attention_mask is None: - encoder_attention_mask = self._get_encoder_attention_mask( - input_ids, attention_mask - ) + encoder_attention_mask = self._get_encoder_attention_mask(input_ids, attention_mask) depth_gate, depth_mask = self._depth_gate_from_condition( input_ids=input_ids, @@ -3321,10 +3078,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): dtype=trajectory.dtype, ) flow_timesteps = [ - torch.full( - (batch_size,), idx / steps, device=device, dtype=torch.float32 - ) - for idx in range(steps) + torch.full((batch_size,), idx / steps, device=device, dtype=torch.float32) for idx in range(steps) ] modulation_cache = action_expert.get_or_prepare_modulation_cache( flow_timesteps, @@ -3337,9 +3091,8 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): action_dim_is_pad=action_dim_is_pad, ) action_cuda_graph_manager = self.action_cuda_graph_manager - if ( - action_cuda_graph_manager is not None - and action_cuda_graph_manager.can_use_action_flow(flow_inputs) + if action_cuda_graph_manager is not None and action_cuda_graph_manager.can_use_action_flow( + flow_inputs ): trajectory = action_cuda_graph_manager.run_action_flow( flow_inputs, steps, self._run_action_flow_loop @@ -3398,15 +3151,11 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): # 2) Map each image index → example index # Example: if counts = [2, 1, 3], then this becomes [0,0,1,2,2,2] - example_ids_for_image = torch.arange(N, device=device).repeat_interleave( - counts - ) # [num_images] + example_ids_for_image = torch.arange(N, device=device).repeat_interleave(counts) # [num_images] assert example_ids_for_image.numel() == num_images # 2-1) Compute crops_per_example by summing per-image crop counts - crops_per_example = torch.zeros( - N, dtype=image_num_crops.dtype, device=image_num_crops.device - ) + crops_per_example = torch.zeros(N, dtype=image_num_crops.dtype, device=image_num_crops.device) crops_per_example.index_add_(0, example_ids_for_image, image_num_crops) # [N] # 2-2) Per-image number of patches = (crops per image) * n_patches @@ -3429,15 +3178,11 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): dtype=num_pooled_patches_per_image.dtype, device=num_pooled_patches_per_image.device, ) - num_pooled_patches_per_example.index_add_( - 0, example_ids_for_image, num_pooled_patches_per_image - ) + num_pooled_patches_per_example.index_add_(0, example_ids_for_image, num_pooled_patches_per_image) # Sanity checks total_crops = int(crops_per_example.sum().item()) - assert total_crops == n_crops, ( - f"Expected {total_crops} crops, but got {n_crops}" - ) + assert total_crops == n_crops, f"Expected {total_crops} crops, but got {n_crops}" total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item()) assert total_num_pooled_patches == image_token_pooling.size(0), ( @@ -3457,9 +3202,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): offset_crop = 0 for i in range(N): num = int(crops_per_example[i].item()) - cur = pixel_values[ - offset_crop : offset_crop + num - ] # [num, n_patches, pixels_per_patch] + cur = pixel_values[offset_crop : offset_crop + num] # [num, n_patches, pixels_per_patch] images[i, :num] = cur offset_crop += num @@ -3484,14 +3227,10 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): num_patches = int(num_pooled_patches_per_example[i].item()) # Subsequence of pooled tokens belonging to this example - cur = image_token_pooling[ - patch_offset : patch_offset + num_patches - ].clone() # [num_patches, dim] + cur = image_token_pooling[patch_offset : patch_offset + num_patches].clone() # [num_patches, dim] index_offset_per_example = index_offset_per_example_list[i] # length = c - per_img_pooled = num_pooled_patches_per_image[ - img_offset : img_offset + c - ] # [c] + per_img_pooled = num_pooled_patches_per_image[img_offset : img_offset + c] # [c] assert len(index_offset_per_example) == per_img_pooled.numel() @@ -3554,9 +3293,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): # 2) Map each video index -> example index # Example: if counts = [2, 1, 3], then this becomes [0,0,1,2,2,2] - example_ids_for_video = torch.arange(N, device=device).repeat_interleave( - counts - ) # [num_videos] + example_ids_for_video = torch.arange(N, device=device).repeat_interleave(counts) # [num_videos] assert example_ids_for_video.numel() == num_videos # 2-1) Compute frames_per_example by summing per-video frame counts @@ -3581,9 +3318,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): # Sanity checks total_frames = int(frames_per_example.sum().item()) - assert total_frames == n_frames, ( - f"Expected {total_frames} frames, but got {n_frames}" - ) + assert total_frames == n_frames, f"Expected {total_frames} frames, but got {n_frames}" total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item()) assert total_num_pooled_patches == video_token_pooling.size(0), ( @@ -3603,9 +3338,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): offset_frame = 0 for i in range(N): num = int(frames_per_example[i].item()) - cur = pixel_values_videos[ - offset_frame : offset_frame + num - ] # [num, n_patches, pixels_per_patch] + cur = pixel_values_videos[offset_frame : offset_frame + num] # [num, n_patches, pixels_per_patch] videos[i, :num] = cur offset_frame += num @@ -3626,9 +3359,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): patch_offset = 0 for i in range(N): num_patches = int(num_pooled_patches_per_example[i].item()) - cur = video_token_pooling[ - patch_offset : patch_offset + num_patches - ] # [num_patches, dim] + cur = video_token_pooling[patch_offset : patch_offset + num_patches] # [num_patches, dim] new_token_pooling[i, :num_patches] = cur patch_offset += num_patches @@ -3649,9 +3380,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): video_grids: Optional[torch.Tensor] = None, ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: if pixel_values is not None and pixel_values_videos is not None: - raise ValueError( - "pixel_values and pixel_values_videos are provided at the same time" - ) + raise ValueError("pixel_values and pixel_values_videos are provided at the same time") elif pixel_values is not None: assert input_ids is not None images, token_pooling = self.build_batched_images( @@ -3724,9 +3453,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): valid_mask = positions.unsqueeze(0) < current_length valid_mask = valid_mask.expand(batch_size, -1) elif attention_mask.ndim == 2: - valid_mask = torch.zeros( - (batch_size, attention_mask_len), device=device, dtype=torch.bool - ) + valid_mask = torch.zeros((batch_size, attention_mask_len), device=device, dtype=torch.bool) source_mask = attention_mask.to(device=device, dtype=torch.bool) copy_len = min(int(source_mask.shape[-1]), attention_mask_len) if copy_len > 0: @@ -3734,15 +3461,11 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): if attention_mask_len > current_length: valid_mask[:, current_length:] = False else: - raise ValueError( - f"Unsupported attention_mask shape for MolmoAct2: {tuple(attention_mask.shape)}" - ) + raise ValueError(f"Unsupported attention_mask shape for MolmoAct2: {tuple(attention_mask.shape)}") valid_mask = valid_mask[:, None, None, :] causal_mask = torch.tril( - torch.ones( - attention_mask_len, attention_mask_len, device=device, dtype=torch.bool - ) + torch.ones(attention_mask_len, attention_mask_len, device=device, dtype=torch.bool) )[None, None, past_length:current_length, :attention_mask_len] if token_type_ids is not None and past_length == 0: @@ -3751,8 +3474,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): can_attend_back = image_mask[:, :, None] & image_mask[:, None, :] image_len = min(int(token_type_ids.shape[1]), attention_mask_len) causal_mask[:, :, :, :image_len] = ( - causal_mask[:, :, :, :image_len] - | can_attend_back[:, None, :, :image_len] + causal_mask[:, :, :, :image_len] | can_attend_back[:, None, :, :image_len] ) allowed = valid_mask & causal_mask @@ -3791,21 +3513,15 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): ) -> Union[tuple, MolmoAct2ModelOutputWithPast]: output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") images, token_pooling = self.merge_visual_inputs( input_ids=input_ids, @@ -3819,9 +3535,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): ) if images is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both images and inputs_embeds at the same time." - ) + raise ValueError("You cannot specify both images and inputs_embeds at the same time.") if inputs_embeds is None: inputs_embeds, image_features = self.build_input_embeddings( @@ -3913,9 +3627,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi if stats is not None: return stats filename = getattr(self.config, "norm_stats_filename", "norm_stats.json") - base_dir = getattr(self.config, "_name_or_path", None) or getattr( - self, "name_or_path", None - ) + base_dir = getattr(self.config, "_name_or_path", None) or getattr(self, "name_or_path", None) if not base_dir: raise ValueError( "MolmoAct2 normalization stats are not loaded and config._name_or_path is empty; " @@ -3939,9 +3651,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi return stats @staticmethod - def _move_inputs_to_device( - inputs: Mapping[str, Any], device: torch.device - ) -> Dict[str, Any]: + def _move_inputs_to_device(inputs: Mapping[str, Any], device: torch.device) -> Dict[str, Any]: out = {} for key, value in inputs.items(): out[key] = value.to(device) if torch.is_tensor(value) else value @@ -3951,9 +3661,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi def _drop_trivial_attention_mask(inputs: Mapping[str, Any]) -> Dict[str, Any]: out = dict(inputs) attention_mask = out.get("attention_mask") - if torch.is_tensor(attention_mask) and bool( - attention_mask.to(dtype=torch.bool).all().item() - ): + if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()): out.pop("attention_mask", None) return out @@ -3982,9 +3690,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi ) if int(action_dim) == int(max_action_dim): return None - mask = torch.ones( - (int(batch_size), int(max_action_dim)), device=device, dtype=torch.bool - ) + mask = torch.ones((int(batch_size), int(max_action_dim)), device=device, dtype=torch.bool) mask[:, : int(action_dim)] = False return mask @@ -4005,35 +3711,24 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi start = int(n_obs_steps) - 1 end = start + int(n_action_steps) if end > actions.shape[1]: - raise ValueError( - f"Requested actions up to {end} but model produced horizon {actions.shape[1]}." - ) + raise ValueError(f"Requested actions up to {end} but model produced horizon {actions.shape[1]}.") return actions[:, start:end] def _depth_token_id_to_bin(self) -> Dict[int, int]: - if ( - self.config.depth_token_start_id is None - or int(self.config.num_depth_tokens or 0) <= 0 - ): + if self.config.depth_token_start_id is None or int(self.config.num_depth_tokens or 0) <= 0: return {} start = int(self.config.depth_token_start_id) return {start + idx: idx for idx in range(int(self.config.num_depth_tokens))} def _action_token_id_to_bin(self) -> Dict[int, int]: - if ( - self.config.action_token_start_id is None - or int(self.config.num_action_tokens or 0) <= 0 - ): + if self.config.action_token_start_id is None or int(self.config.num_action_tokens or 0) <= 0: return {} start = int(self.config.action_token_start_id) return {start + idx: idx for idx in range(int(self.config.num_action_tokens))} def _require_eos_token_id(self) -> int: eos_token_id = getattr(self.config, "eos_token_id", None) - if ( - eos_token_id is None - and getattr(self, "generation_config", None) is not None - ): + if eos_token_id is None and getattr(self, "generation_config", None) is not None: eos_token_id = getattr(self.generation_config, "eos_token_id", None) if isinstance(eos_token_id, (list, tuple)): eos_token_id = eos_token_id[0] if eos_token_id else None @@ -4043,21 +3738,12 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi ) return int(eos_token_id) - def _decode_depth_bins_from_token_ids( - self, token_ids: torch.Tensor - ) -> torch.Tensor: - if ( - self.config.depth_start_token_id is None - or self.config.depth_end_token_id is None - ): - raise RuntimeError( - "Depth generation requires / token IDs." - ) + def _decode_depth_bins_from_token_ids(self, token_ids: torch.Tensor) -> torch.Tensor: + if self.config.depth_start_token_id is None or self.config.depth_end_token_id is None: + raise RuntimeError("Depth generation requires / token IDs.") token_id_to_bin = self._depth_token_id_to_bin() if not token_id_to_bin: - raise RuntimeError( - "Depth generation requires indexed depth tokens in the converted config." - ) + raise RuntimeError("Depth generation requires indexed depth tokens in the converted config.") depth_token_bins = _extract_discrete_token_bins( _flatten_generated_token_ids(token_ids), int(self.config.depth_start_token_id), @@ -4065,9 +3751,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi token_id_to_bin, ) if not depth_token_bins: - raise RuntimeError( - "Model generated no decodable depth tokens between /." - ) + raise RuntimeError("Model generated no decodable depth tokens between /.") return torch.as_tensor([depth_token_bins], device=self.device, dtype=torch.long) def _consume_generation_tokens( @@ -4082,9 +3766,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi elif token_ids.ndim == 2: next_input_ids = token_ids else: - raise ValueError( - f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}." - ) + raise ValueError(f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}.") next_attention_mask = attention_mask if next_attention_mask is not None: past_length = _cache_seq_len_int(past_key_values) @@ -4094,9 +3776,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi next_attention_mask = torch.cat( ( next_attention_mask, - next_attention_mask.new_ones( - (next_input_ids.shape[0], pad_len) - ), + next_attention_mask.new_ones((next_input_ids.shape[0], pad_len)), ), dim=-1, ) @@ -4124,18 +3804,14 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi layers = getattr(past_key_values, "layers", None) max_cache_len = int(getattr(layers[0], "max_cache_len", 0)) if layers else 0 if max_cache_len <= 0: - raise RuntimeError( - "Depth decode fast path requires a cache with a fixed maximum length." - ) + raise RuntimeError("Depth decode fast path requires a cache with a fixed maximum length.") input_ids = inputs["input_ids"] batch_size = int(input_ids.shape[0]) device = input_ids.device dtype = self.lm_head.weight.dtype positions = torch.arange(max_cache_len, device=device, dtype=torch.long) - valid_mask = torch.ones( - (batch_size, max_cache_len), device=device, dtype=torch.bool - ) + valid_mask = torch.ones((batch_size, max_cache_len), device=device, dtype=torch.bool) attention_mask = inputs.get("attention_mask") if attention_mask is not None: source_mask = attention_mask.to(device=device, dtype=torch.bool) @@ -4172,9 +3848,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi elif token_ids.ndim == 2: next_input_ids = token_ids else: - raise ValueError( - f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}." - ) + raise ValueError(f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}.") past_length = _cache_seq_len_int(past_key_values) end = past_length + int(next_input_ids.shape[1]) if self.depth_decode_cuda_graph_manager.can_use( @@ -4188,9 +3862,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi attention_bias=attention_bias, past_length=past_length, ) - cache_position = torch.arange( - past_length, end, device=next_input_ids.device, dtype=torch.long - ) + cache_position = torch.arange(past_length, end, device=next_input_ids.device, dtype=torch.long) attention_bias = attention_bias[:, :, past_length:end, :end] inputs_embeds = self._embed_base_tokens(next_input_ids) outputs = self.model.transformer( @@ -4241,10 +3913,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi max_end_steps = max(8, action_horizon) action_token_budget = max(1, action_horizon * 16) return self.depth_decode_cuda_graph_manager.make_static_cache( - max_cache_len=prompt_len - + self._max_depth_decode_steps() - + max_end_steps - + action_token_budget, + max_cache_len=prompt_len + self._max_depth_decode_steps() + max_end_steps + action_token_budget, ) def _continue_discrete_generation_from_output( @@ -4301,17 +3970,9 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi depth_cache: Optional[Mapping[str, Any]], enable_adaptive_depth: bool, ) -> _DepthPrefix: - if ( - self.config.depth_start_token_id is None - or self.config.depth_end_token_id is None - ): - raise RuntimeError( - "Depth reasoning requires single-token /." - ) - if ( - self.config.depth_token_start_id is None - or int(self.config.num_depth_tokens or 0) <= 0 - ): + if self.config.depth_start_token_id is None or self.config.depth_end_token_id is None: + raise RuntimeError("Depth reasoning requires single-token /.") + if self.config.depth_token_start_id is None or int(self.config.num_depth_tokens or 0) <= 0: raise RuntimeError("Depth reasoning requires indexed depth tokens.") batch_size = int(inputs["input_ids"].shape[0]) if batch_size != 1 and enable_adaptive_depth: @@ -4329,12 +3990,10 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi for _ in range(max_steps): next_token = torch.argmax(current_output.logits[:, -1, :], dim=-1) generated_tokens.append(next_token) - current_output, current_attention_mask = ( - self._consume_generation_tokens( - next_token, - past_key_values=current_past_key_values, - attention_mask=current_attention_mask, - ) + current_output, current_attention_mask = self._consume_generation_tokens( + next_token, + past_key_values=current_past_key_values, + attention_mask=current_attention_mask, ) current_past_key_values = current_output.past_key_values if bool((next_token == int(self.config.depth_end_token_id)).all()): @@ -4343,16 +4002,12 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi if not generated_tokens: raise RuntimeError("Depth generation produced no tokens.") if not hit_depth_end: - raise RuntimeError( - f"Depth generation did not emit within {max_steps} steps." - ) + raise RuntimeError(f"Depth generation did not emit within {max_steps} steps.") depth_token_ids = torch.stack(generated_tokens, dim=1) full_input_ids = torch.cat([inputs["input_ids"], depth_token_ids], dim=1) full_attention_mask = None if current_attention_mask is not None: - full_attention_mask = current_attention_mask[ - :, : full_input_ids.shape[1] - ] + full_attention_mask = current_attention_mask[:, : full_input_ids.shape[1]] encoder_kv_states = self.model._extract_kv_states(current_past_key_values) return _DepthPrefix( token_ids=depth_token_ids, @@ -4376,9 +4031,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi device=self.device, dtype=torch.long, ) - depth_attention_bias = self._make_depth_decode_attention_bias( - inputs, current_past_key_values - ) + depth_attention_bias = self._make_depth_decode_attention_bias(inputs, current_past_key_values) generated_tokens.append(depth_start) last_hidden, current_past_key_values = self._run_depth_decode_step( depth_start, @@ -4436,9 +4089,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi attention_bias=depth_attention_bias, ) else: - for start_idx, end_idx, should_generate in _build_depth_update_spans( - update_mask - ): + for start_idx, end_idx, should_generate in _build_depth_update_spans(update_mask): if should_generate: for depth_idx in range(start_idx, end_idx): depth_logits = self._project_depth_logits(last_hidden) @@ -4446,17 +4097,13 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi depth_bins[:, depth_idx] = predicted_bins chosen_token_ids = code_token_ids[predicted_bins] generated_tokens.append(chosen_token_ids) - last_hidden, current_past_key_values = ( - self._run_depth_decode_step( - chosen_token_ids, - past_key_values=current_past_key_values, - attention_bias=depth_attention_bias, - ) + last_hidden, current_past_key_values = self._run_depth_decode_step( + chosen_token_ids, + past_key_values=current_past_key_values, + attention_bias=depth_attention_bias, ) continue - replay_bins = previous_buffer_t[:, start_idx:end_idx].expand( - batch_size, -1 - ) + replay_bins = previous_buffer_t[:, start_idx:end_idx].expand(batch_size, -1) depth_bins[:, start_idx:end_idx] = replay_bins replay_token_ids = code_token_ids[replay_bins] generated_tokens.extend(replay_token_ids.unbind(dim=1)) @@ -4520,16 +4167,9 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi action_horizon: int, ) -> torch.Tensor: if action_tokenizer is None: - raise ValueError( - "inference_action_mode='discrete' requires an `action_tokenizer` input." - ) - if ( - self.config.action_start_token_id is None - or self.config.action_end_token_id is None - ): - raise RuntimeError( - "Discrete action generation requires / token IDs." - ) + raise ValueError("inference_action_mode='discrete' requires an `action_tokenizer` input.") + if self.config.action_start_token_id is None or self.config.action_end_token_id is None: + raise RuntimeError("Discrete action generation requires / token IDs.") token_id_to_bin = self._action_token_id_to_bin() if not token_id_to_bin: raise RuntimeError( @@ -4559,13 +4199,9 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi elif action_chunk.ndim == 2: action_chunk = action_chunk[None, :, :] elif action_chunk.ndim > 3: - action_chunk = action_chunk.reshape( - 1, action_chunk.shape[-2], action_chunk.shape[-1] - ) + action_chunk = action_chunk.reshape(1, action_chunk.shape[-2], action_chunk.shape[-1]) if action_chunk.ndim != 3: - raise RuntimeError( - f"Decoded action chunk has unexpected shape {action_chunk.shape}." - ) + raise RuntimeError(f"Decoded action chunk has unexpected shape {action_chunk.shape}.") return torch.as_tensor(action_chunk, device=self.device, dtype=torch.float32) @torch.no_grad() @@ -4590,13 +4226,10 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi return_dict: bool = True, ) -> Union[MolmoAct2ActionOutput, torch.Tensor]: if state is None: - raise ValueError( - "MolmoAct2 `predict_action` requires `state` for discrete state prompting." - ) + raise ValueError("MolmoAct2 `predict_action` requires `state` for discrete state prompting.") if inference_action_mode is None: raise ValueError( - "`inference_action_mode` must be provided explicitly as either " - "'continuous' or 'discrete'." + "`inference_action_mode` must be provided explicitly as either 'continuous' or 'discrete'." ) inference_action_mode = str(inference_action_mode) if inference_action_mode not in {"continuous", "discrete"}: @@ -4616,33 +4249,25 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi ) if inference_action_mode == "discrete": if action_tokenizer is None: - raise ValueError( - "inference_action_mode='discrete' requires an `action_tokenizer` input." - ) + raise ValueError("inference_action_mode='discrete' requires an `action_tokenizer` input.") if self.config.action_mode not in {"discrete", "both"}: raise ValueError( "inference_action_mode='discrete' requires checkpoint action_mode in " f"{{'discrete', 'both'}}, got {self.config.action_mode!r}." ) if enable_depth_reasoning and not bool(self.config.enable_depth_reasoning): - raise ValueError( - "this model was not trained with `--enable_depth_reasoning`." - ) + raise ValueError("this model was not trained with `--enable_depth_reasoning`.") stats = self._get_robot_stats() norm_tag = stats.validate_tag(norm_tag) metadata = stats.get_metadata(norm_tag) - normalized_state = np.asarray( - stats.normalize_state(state, norm_tag), dtype=np.float32 - ) + normalized_state = np.asarray(stats.normalize_state(state, norm_tag), dtype=np.float32) num_state_tokens = int(self.config.num_state_tokens or 0) if num_state_tokens <= 0: raise RuntimeError( "Discrete state prompting requires indexed state tokens in the converted config." ) - discrete_state_string = _build_discrete_state_string( - normalized_state, num_state_tokens - ) + discrete_state_string = _build_discrete_state_string(normalized_state, num_state_tokens) style = "robot_depth_action" if enable_depth_reasoning else "robot_action" task_text = str(task or "") if normalize_language: @@ -4679,9 +4304,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi resolved_n_action_steps = int(action_horizon) resolved_n_action_steps = int(resolved_n_action_steps) if resolved_n_action_steps < 1: - raise ValueError( - f"n_action_steps must be >= 1, got {resolved_n_action_steps}." - ) + raise ValueError(f"n_action_steps must be >= 1, got {resolved_n_action_steps}.") if resolved_n_action_steps > int(action_horizon): raise ValueError( f"Requested n_action_steps={resolved_n_action_steps} exceeds tag action_horizon={int(action_horizon)}." @@ -4726,11 +4349,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi if latest_first_image is not None: updated_depth_cache = { "image": latest_first_image, - "depth_bins": depth_bins.detach() - .cpu() - .reshape(-1) - .numpy() - .astype(np.int64), + "depth_bins": depth_bins.detach().cpu().reshape(-1).numpy().astype(np.int64), } else: actions = self.model.generate_actions_from_inputs( @@ -4756,18 +4375,12 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi end_token_id=self._require_eos_token_id(), max_steps=max(1, int(generation_horizon * 16)), ) - generated_token_ids = torch.cat( - [depth_prefix.token_ids, action_token_ids], dim=1 - ) + generated_token_ids = torch.cat([depth_prefix.token_ids, action_token_ids], dim=1) depth_bins = depth_prefix.depth_bins if latest_first_image is not None: updated_depth_cache = { "image": latest_first_image, - "depth_bins": depth_bins.detach() - .cpu() - .reshape(-1) - .numpy() - .astype(np.int64), + "depth_bins": depth_bins.detach().cpu().reshape(-1).numpy().astype(np.int64), } else: max_action_decode_steps = max(1, int(generation_horizon * 16)) @@ -4805,9 +4418,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi ) actions = self._slice_action_dim(actions, action_dim) - actions = self._slice_action_chunk( - actions, int(self.config.n_obs_steps), resolved_n_action_steps - ) + actions = self._slice_action_chunk(actions, int(self.config.n_obs_steps), resolved_n_action_steps) actions = stats.unnormalize_action(actions, norm_tag) if not torch.is_tensor(actions): actions = torch.as_tensor(actions, device=self.device, dtype=torch.float32) @@ -4894,18 +4505,12 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - loss = self.loss_function( - logits=logits, labels=labels, vocab_size=self.vocab_size - ) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size) return MolmoAct2CausalLMOutputWithPast( loss=loss, diff --git a/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py index 43a927a71..e01284bc8 100644 --- a/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py @@ -19,6 +19,7 @@ """ Processor class for MolmoAct2. """ + from typing import Optional, Union import dataclasses @@ -50,7 +51,7 @@ IM_START_TOKEN = f"" LOW_RES_IMAGE_START_TOKEN = f"" FRAME_START_TOKEN = f"" IM_END_TOKEN = f"" -FRAME_END_TOKEN= f"" +FRAME_END_TOKEN = f"" IM_COL_TOKEN = f"" IMAGE_PROMPT = "<|image|>" VIDEO_PROMPT = "<|video|>" @@ -69,6 +70,7 @@ IMAGE_TOKENS = [ class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False): """MolmoAct2 processor kwargs""" + images_kwargs: MolmoAct2ImagesKwargs videos_kwargs: MolmoAct2VideoProcessorKwargs _defaults = { @@ -106,7 +108,7 @@ class MolmoAct2Processor(ProcessorMixin): use_single_crop_start_token: Optional[bool] = True, video_use_col_tokens: Optional[bool] = False, use_frame_special_tokens: Optional[bool] = True, - **kwargs + **kwargs, ) -> None: super().__init__( image_processor, @@ -122,10 +124,7 @@ class MolmoAct2Processor(ProcessorMixin): self.image_placeholder_token = IMAGE_PROMPT self.video_placeholder_token = VIDEO_PROMPT - self.image_token_ids = [ - tokenizer.convert_tokens_to_ids(token) - for token in IMAGE_TOKENS - ] + self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS] def get_image_tokens(self, image_grid: np.ndarray): resized_h, resized_w, height, width = image_grid @@ -158,11 +157,7 @@ class MolmoAct2Processor(ProcessorMixin): if self.use_single_crop_col_tokens is None else self.use_single_crop_col_tokens ) - image_start_token = ( - LOW_RES_IMAGE_START_TOKEN - if self.use_single_crop_start_token - else IM_START_TOKEN - ) + image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN if use_single_crop_col_tokens: per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0) joint = [ @@ -190,7 +185,7 @@ class MolmoAct2Processor(ProcessorMixin): for frame_idx, frame_time in enumerate(timestamps): # `per-frame-compact` time mode prev_space = " " if frame_idx > 0 else "" - frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens + frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens video_string += frame_prefix per_row = np.full(w, IMAGE_PATCH_TOKEN) @@ -249,8 +244,8 @@ class MolmoAct2Processor(ProcessorMixin): attention_mask = attention_mask[0] return input_ids, attention_mask else: - new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype) - new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype) + new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype) + new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype) src_idx = np.tile(np.arange(S), (B, 1)) # [B, S] valid_mask = src_idx >= first_valid_index[:, None] # [B, S] @@ -349,13 +344,13 @@ class MolmoAct2Processor(ProcessorMixin): if not isinstance(text, list): text = [text] - text = text.copy() # below lines change text in-place + text = text.copy() # below lines change text in-place if image_grids is not None: index = 0 for i in range(len(text)): num_images = text[i].count(self.image_placeholder_token) - image_grids_i = image_grids[index:index+num_images] + image_grids_i = image_grids[index : index + num_images] for image_grid in image_grids_i: image_tokens = self.get_image_tokens(image_grid) image_string = "".join(image_tokens) @@ -367,8 +362,8 @@ class MolmoAct2Processor(ProcessorMixin): for i in range(len(text)): num_videos = text[i].count(self.video_placeholder_token) assert num_videos in {0, 1}, "At most one video is supported for now" - video_grids_i = video_grids[index:index+num_videos] - metadata_i = video_metadata[index:index+num_videos] + video_grids_i = video_grids[index : index + num_videos] + metadata_i = video_metadata[index : index + num_videos] for video_grid, metadata in zip(video_grids_i, metadata_i): video_string = self.get_video_string( video_grid, diff --git a/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py index 1432633de..01763370a 100644 --- a/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py @@ -17,6 +17,7 @@ # ruff: noqa """Video processor class for MolmoAct2""" + from functools import partial import os import warnings @@ -100,7 +101,9 @@ def resize_image( )(image) resized = torch.clip(resized, 0.0, 1.0).to(dtype) else: - assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype) + assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format( + image.dtype + ) in_min = 0.0 in_max = 255.0 resized = torchvision.transforms.Resize( @@ -130,14 +133,16 @@ def build_resized_image( image_patch_size: int, ) -> tuple[np.ndarray, np.ndarray]: resized = resize_image( - image, base_image_input_size, resample, + image, + base_image_input_size, + resample, ) resized = normalize_image(resized, image_mean, image_std) if len(resized.shape) == 3: resized = np.expand_dims(resized, 0) crop_patch_w = base_image_input_size[1] // image_patch_size crop_patch_h = base_image_input_size[0] // image_patch_size - resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w]) + resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w]) return resized, resize_idx @@ -145,19 +150,19 @@ def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray: """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]""" if len(array.shape) == 3: n_crops, h, w = array.shape - h_patches = h//patch_size - w_patches = w//patch_size + h_patches = h // patch_size + w_patches = w // patch_size array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size]) array = np.transpose(array, [0, 1, 3, 2, 4]) - array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size]) + array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size]) return array else: n_crops, h, w, c = array.shape - h_patches = h//patch_size - w_patches = w//patch_size + h_patches = h // patch_size + w_patches = w // patch_size array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c]) array = np.transpose(array, [0, 1, 3, 2, 4, 5]) - array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c]) + array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c]) return array @@ -168,10 +173,13 @@ def arange_for_pooling( ) -> np.ndarray: h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0] w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1] - idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]], - mode='constant',constant_values=-1) - return einops.rearrange( - idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w) + idx_arr = np.pad( + idx_arr, + [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]], + mode="constant", + constant_values=-1, + ) + return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w) def image_to_patches_and_grids( @@ -206,7 +214,7 @@ def image_to_patches_and_grids( ) pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w) h, w = pooling_idx.shape[:2] - pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w]) + pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w]) image_grid = [h, w] return ( image_grid, @@ -277,6 +285,7 @@ def read_video_decord( """ # Lazy import from decord import importlib + decord = importlib.import_module("decord") vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu @@ -296,7 +305,7 @@ def read_video_decord( target_timestamps = np.array(target_timestamps) offset = time_stamps[0, 0] - ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side='right') + ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right") ix = np.minimum(ix, len(time_stamps) - 1) video = vr.get_batch(ix).asnumpy() @@ -331,6 +340,7 @@ def read_video_torchcodec( """ # Lazy import torchcodec import importlib + torchcodec = importlib.import_module("torchcodec") decoder = torchcodec.decoders.VideoDecoder( @@ -360,7 +370,7 @@ def read_video_torchcodec( # Floating point/rounding issues might cause `target_timestamps` to be very slightly # out-of-bounds, to handle this we sanity check then clip them assert all(x >= 0 for x in target_timestamps) - assert all(x < duration+1e-6 for x in target_timestamps) + assert all(x < duration + 1e-6 for x in target_timestamps) # 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the # exact boundary value, we should still get the first/last frame anyway max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6 @@ -369,7 +379,9 @@ def read_video_torchcodec( timestamps = [x + time_offset for x in target_timestamps] timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps] - video = decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1) # Convert to THWC format + video = ( + decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1) + ) # Convert to THWC format target_timestamps = np.array(target_timestamps) metadata.frames_indices = target_timestamps * metadata.fps @@ -397,6 +409,7 @@ def read_video_pyav( """ # Lazy import torchcodec import importlib + av = importlib.import_module("av") with av.open(video_path) as container: @@ -413,7 +426,7 @@ def read_video_pyav( if container_end is None or container_end < frames[-1].pts: # Some problem with stream duration, so use the frame PTS directly # and guess the duration of the last frame - end = frames[-1].pts * stream.time_base + 1/fps + end = frames[-1].pts * stream.time_base + 1 / fps else: end = container_end duration = float(end - start) @@ -432,7 +445,7 @@ def read_video_pyav( target_timestamps = np.array(target_timestamps) end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration]) - indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side='right') + indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right") indices = np.minimum(indices, len(end_time_stamps) - 1) video = np.stack( @@ -480,6 +493,7 @@ def load_video( raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.") # Lazy import from yt_dlp import importlib + yt_dlp = importlib.import_module("yt_dlp") buffer = BytesIO() @@ -492,7 +506,9 @@ def load_video( elif os.path.isfile(video): file_obj = video else: - raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.") + raise TypeError( + "Incorrect format used for video. Should be an url linking to an video or a local path." + ) # can also load with decord, but not cv2/torchvision # both will fail in case of url links @@ -551,12 +567,7 @@ def get_target_fps( return selected_target_fps -def get_frame_times_and_chosen_fps( - selected_target_fps, - total_frames, - max_frames, - video_fps -): +def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames, video_fps): if selected_target_fps is None: frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int) else: @@ -656,19 +667,15 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): return times elif frame_sample_mode == "uniform_last_frame": if max_fps is not None: - max_duration = (num_frames-1) / max_fps # -1 to include the last frame + max_duration = (num_frames - 1) / max_fps # -1 to include the last frame if max_duration < duration: - times = np.linspace( - 0, duration, num=num_frames, endpoint=True, dtype=np.float64 - ) + times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64) else: - times = np.arange(0.0, stop=duration, step=1/max_fps) + times = np.arange(0.0, stop=duration, step=1 / max_fps) times = np.concatenate([times, [duration]], axis=0) assert len(times) <= num_frames else: - times = np.linspace( - 0, duration, num=num_frames, endpoint=True, dtype=np.float64 - ) + times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64) return times else: raise NotImplementedError(frame_sample_mode) @@ -717,7 +724,9 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): return indices else: float_indices = np.arange( - 0.0, stop=total_num_frames - 1, step=float(metadata.fps / max_fps), + 0.0, + stop=total_num_frames - 1, + step=float(metadata.fps / max_fps), ) if np.round(float_indices[-1]) != total_num_frames - 1: float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0) @@ -727,7 +736,10 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): return indices elif frame_sample_mode == "uniform_last_frame": indices = np.linspace( - 0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True, + 0, + total_num_frames - 1, + num=min(num_frames, total_num_frames), + endpoint=True, ).astype(int) return indices elif frame_sample_mode == "fps": @@ -750,9 +762,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): raise NotImplementedError(frame_sample_mode) def fetch_videos( - self, - video_url_or_urls: Union[str, list[str], list[list[str]]], - sample_timestamps_fn=None + self, video_url_or_urls: Union[str, list[str], list[list[str]]], sample_timestamps_fn=None ): """ Convert a single or a list of urls into the corresponding `np.array` objects. @@ -760,11 +770,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): If a single url is passed, the return value will be a single object. If a list is passed a list of objects is returned. """ - if ( - (not is_decord_available()) - and (not is_torchcodec_available()) - and (not is_av_available()) - ): + if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()): raise ImportError( "MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed." ) @@ -785,7 +791,14 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): backend = "pyav" if isinstance(video_url_or_urls, list): - return list(zip(*[self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) for x in video_url_or_urls])) + return list( + zip( + *[ + self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) + for x in video_url_or_urls + ] + ) + ) else: return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn) @@ -823,9 +836,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): "Will decode the video and sample frames using MolmoAct2's default sampling mode" ) if isinstance(videos[0], list): - raise ValueError( - "A list of images is not supported for video input!" - ) + raise ValueError("A list of images is not supported for video input!") else: videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn) @@ -975,7 +986,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): pixel_values_videos = np.concatenate(batch_crops, 0) video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0) - data =dict( + data = dict( pixel_values_videos=pixel_values_videos, video_token_pooling=video_token_pooling, video_grids=video_grids, diff --git a/src/lerobot/policies/molmoact2/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/modeling_molmoact2.py index 288837268..f86be0904 100644 --- a/src/lerobot/policies/molmoact2/modeling_molmoact2.py +++ b/src/lerobot/policies/molmoact2/modeling_molmoact2.py @@ -136,7 +136,6 @@ def _sample_beta_timesteps( return time_offset + scale * samples - class MolmoAct2Policy(PreTrainedPolicy): config_class = MolmoAct2Config name = "molmoact2"