diff --git a/pyproject.toml b/pyproject.toml index 4afdb63d8..fc686a6c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -405,6 +405,8 @@ default.extend-ignore-identifiers-re = [ "ein", "thw", "inpt", + "arange", + "is_compileable", "ROBOTIS", "OT_VALUE", "VanderBilt" diff --git a/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py index c66c81fe0..29da68c14 100644 --- a/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/configuration_molmoact2.py @@ -206,7 +206,7 @@ class MolmoAct2TextConfig(PretrainedConfig): self, hidden_size: int = 3584, num_attention_heads: int = 28, - num_key_value_heads: Optional[int] = 4, + num_key_value_heads: int | None = 4, head_dim: int = 128, vocab_size: int = 152064, additional_vocab_size: int = 128, @@ -220,7 +220,7 @@ class MolmoAct2TextConfig(PretrainedConfig): max_position_embeddings: int = 4096, rope_theta: float = 1000000.0, rope_scaling: dict[str, Any] = None, - rope_scaling_layers: Optional[list[int]] = None, + rope_scaling_layers: list[int] | None = None, use_qk_norm: bool = False, qk_norm_type: str = "olmo", layer_norm_eps: int = 1e-6, diff --git a/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py index 16d2afec9..a172c8477 100644 --- a/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/image_processing_molmoact2.py @@ -364,11 +364,11 @@ def image_to_patches_and_grids( class MolmoAct2ImagesKwargs(ImagesKwargs, total=False): - max_crops: Optional[int] - overlap_margins: Optional[list[int]] - crop_mode: Optional[str] - patch_size: Optional[int] - pooling_size: Optional[list[int]] + max_crops: int | None + overlap_margins: list[int] | None + crop_mode: str | None + patch_size: int | None + pooling_size: list[int] | None class MolmoAct2ImageProcessor(BaseImageProcessor): @@ -400,10 +400,10 @@ class MolmoAct2ImageProcessor(BaseImageProcessor): def __init__( self, - size: Optional[dict[str, int]] = None, + size: dict[str, int] | None = None, resample: PILImageResampling = PILImageResampling.BILINEAR, - image_mean: Optional[Union[float, list[float]]] = None, - image_std: Optional[Union[float, list[float]]] = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, do_convert_rgb: bool = True, max_crops: int = 8, overlap_margins: list[int] = [4, 4], @@ -431,17 +431,17 @@ class MolmoAct2ImageProcessor(BaseImageProcessor): def preprocess( self, images: ImageInput, - size: Optional[dict[str, int]] = None, - resample: Optional[PILImageResampling] = None, - image_mean: Optional[Union[float, list[float]]] = None, - image_std: Optional[Union[float, list[float]]] = None, - do_convert_rgb: Optional[bool] = None, - max_crops: Optional[int] = None, - overlap_margins: Optional[list[int]] = None, - crop_mode: Optional[str] = None, - patch_size: Optional[int] = None, - pooling_size: Optional[list[int]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + size: dict[str, int] | None = None, + resample: PILImageResampling | None = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + do_convert_rgb: bool | None = None, + max_crops: int | None = None, + overlap_margins: list[int] | None = None, + crop_mode: str | None = None, + patch_size: int | None = None, + pooling_size: list[int] | None = None, + return_tensors: str | TensorType | None = None, **kwargs, ) -> BatchFeature: """ diff --git a/src/lerobot/policies/molmoact2/hf_model/inference.py b/src/lerobot/policies/molmoact2/hf_model/inference.py index 1bfcb8178..2c0243880 100644 --- a/src/lerobot/policies/molmoact2/hf_model/inference.py +++ b/src/lerobot/policies/molmoact2/hf_model/inference.py @@ -19,7 +19,8 @@ """Inference utilities for MolmoAct2""" from dataclasses import dataclass -from typing import Any, Iterable, Optional, Sequence, Tuple +from typing import Any, Optional, Tuple +from collections.abc import Iterable, Sequence import torch from torch.nn import functional as F @@ -32,12 +33,12 @@ class _ActionFlowInputs: trajectory: torch.Tensor context: Any modulations: Sequence[Any] - action_dim_is_pad: Optional[torch.Tensor] + action_dim_is_pad: torch.Tensor | None @dataclass class _ActionFlowCudaGraph: - key: Tuple[Any, ...] + key: tuple[Any, ...] graph: torch.cuda.CUDAGraph static_inputs: _ActionFlowInputs output: torch.Tensor @@ -59,7 +60,7 @@ class _DepthDecodeCudaGraphPostStage: @dataclass class _DepthDecodeCudaGraph: - cache_key: Tuple[Any, ...] + cache_key: tuple[Any, ...] pre_graph: torch.cuda.CUDAGraph token_ids: torch.Tensor cos: torch.Tensor @@ -73,13 +74,13 @@ class _DepthDecodeCudaGraph: @dataclass class _DepthDecodeCudaGraphSpec: eligible: bool - cache_key_prefix: Tuple[Any, ...] + cache_key_prefix: tuple[Any, ...] num_hidden_layers: int head_dim: int num_attention_heads: int -def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int: +def _cache_seq_len_int(past_key_values: Cache | None) -> int: if past_key_values is None: return 0 seq_len = past_key_values.get_seq_length() @@ -88,7 +89,7 @@ def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int: return int(seq_len) -def _cache_max_len_int(past_key_values: Optional[Cache]) -> int: +def _cache_max_len_int(past_key_values: Cache | None) -> int: if past_key_values is None: return -1 max_len = past_key_values.get_max_cache_shape() @@ -99,7 +100,7 @@ def _cache_max_len_int(past_key_values: Optional[Cache]) -> int: def _iter_cache_key_values( past_key_values: Cache, -) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]: +) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]: layers = getattr(past_key_values, "layers", None) if layers is not None: for layer in layers: @@ -116,8 +117,8 @@ class _DepthDecodeStaticLayerCache: def __init__(self, max_cache_len: int) -> None: self.max_cache_len = int(max_cache_len) self.cumulative_length = 0 - self.keys: Optional[torch.Tensor] = None - self.values: Optional[torch.Tensor] = None + self.keys: torch.Tensor | None = None + self.values: torch.Tensor | None = None def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: bsz, n_heads = key_states.shape[:2] @@ -138,7 +139,7 @@ class _DepthDecodeStaticLayerCache: value_states: torch.Tensor, *args, **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: if self.keys is None: self._allocate(key_states, value_states) start = self.cumulative_length @@ -185,7 +186,7 @@ class ActionCudaGraphManager: def __init__(self, model: Any) -> None: self.model = model self.enabled = True - self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None + self.action_flow_graph: _ActionFlowCudaGraph | None = None def set_enabled(self, enabled: bool) -> None: self.enabled = bool(enabled) @@ -256,8 +257,8 @@ class DepthDecodeCudaGraphManager: self.model = model self.backbone = model.model self.enabled = True - self.graph: Optional[_DepthDecodeCudaGraph] = None - self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None + self.graph: _DepthDecodeCudaGraph | None = None + self.graph_spec: _DepthDecodeCudaGraphSpec | None = None def set_enabled(self, enabled: bool) -> None: self.enabled = bool(enabled) @@ -320,7 +321,7 @@ class DepthDecodeCudaGraphManager: self, next_input_ids: torch.Tensor, attention_bias: torch.Tensor, - ) -> Tuple[Any, ...]: + ) -> tuple[Any, ...]: device = next_input_ids.device return ( self._depth_decode_spec().cache_key_prefix, @@ -341,7 +342,7 @@ class DepthDecodeCudaGraphManager: hidden_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: block = self.backbone.transformer.blocks[layer_idx] attention = block.self_attn residual = hidden_states @@ -378,7 +379,7 @@ class DepthDecodeCudaGraphManager: token_ids: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: inputs_embeds = self.model._embed_base_tokens(token_ids) return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin) @@ -408,7 +409,7 @@ class DepthDecodeCudaGraphManager: attn_context: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context) return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin) @@ -553,7 +554,7 @@ class DepthDecodeCudaGraphManager: past_key_values: Cache, attention_bias: torch.Tensor, past_length: int, - ) -> Tuple[torch.Tensor, Cache]: + ) -> tuple[torch.Tensor, Cache]: end = past_length + 1 decode_graph = self._get_depth_decode_graph( next_input_ids, @@ -582,8 +583,8 @@ class DepthDecodeCudaGraphManager: def _cuda_graph_tensor_signature( - tensor: Optional[torch.Tensor], -) -> Optional[Tuple[Any, ...]]: + tensor: torch.Tensor | None, +) -> tuple[Any, ...] | None: if tensor is None: return None return ( @@ -594,7 +595,7 @@ def _cuda_graph_tensor_signature( ) -def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]: +def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]: sig = _cuda_graph_tensor_signature return ( tuple((sig(k), sig(v)) for k, v in context.kv_contexts), @@ -605,7 +606,7 @@ def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]: ) -def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]: +def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]: sig = _cuda_graph_tensor_signature return tuple( ( @@ -617,7 +618,7 @@ def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, . ) -def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]: +def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]: sig = _cuda_graph_tensor_signature return ( sig(inputs.trajectory), @@ -628,7 +629,7 @@ def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]: ) -def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: +def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None: if tensor is None: return None static = torch.empty_strided( @@ -711,7 +712,7 @@ def _apply_rotary_pos_emb( cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (_rotate_half(q) * sin) @@ -732,7 +733,7 @@ def _capture_cuda_graph( device: torch.device, *, after_warmup=None, -) -> Tuple[torch.cuda.CUDAGraph, Any]: +) -> tuple[torch.cuda.CUDAGraph, Any]: warmup_stream = torch.cuda.Stream(device=device) warmup_stream.wait_stream(torch.cuda.current_stream(device)) with torch.cuda.stream(warmup_stream): diff --git a/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py index e0e026c4f..4c36b04c8 100644 --- a/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py @@ -24,7 +24,8 @@ import os import re from copy import deepcopy from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union +from collections.abc import Callable, Mapping, Sequence import numpy as np import torch @@ -75,20 +76,20 @@ from .inference import ( logger = logging.get_logger(__name__) -ACTION_START_TOKEN = "" -ACTION_END_TOKEN = "" -ACTION_OUTPUT_TOKEN = "" -STATE_START_TOKEN = "" -STATE_END_TOKEN = "" -STATE_TOKEN_PREFIX = " Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: half_dim = self.head_dim // 2 inv_freq = 1.0 / ( self.base ** (torch.arange(0, half_dim, device=device, dtype=torch.float32) / max(half_dim, 1)) @@ -206,8 +207,8 @@ class ActionExpertRotaryEmbedding(nn.Module): q: torch.Tensor, k: torch.Tensor, *, - rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + rope_cache: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: if rope_cache is None: rope_cache = self.build_cache(seq_len=q.shape[-2], device=q.device, dtype=q.dtype) cos, sin = rope_cache @@ -246,7 +247,7 @@ class ActionExpertSelfAttention(nn.Module): self.out_proj = nn.Linear(hidden_size, hidden_size) self.out_drop = nn.Dropout(proj_dropout) - def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.q_norm is None or self.k_norm is None: return q, k return self.q_norm(q), self.k_norm(k) @@ -257,7 +258,7 @@ class ActionExpertSelfAttention(nn.Module): k: torch.Tensor, v: torch.Tensor, *, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: torch.Tensor | None = None, is_causal: bool = False, ) -> torch.Tensor: dropout_p = self.attn_dropout if self.training else 0.0 @@ -275,9 +276,9 @@ class ActionExpertSelfAttention(nn.Module): self, x: torch.Tensor, *, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: torch.Tensor | None = None, is_causal: bool = False, - rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rope_cache: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: bsz, seq_len, _ = x.shape qkv = self.qkv(x).view(bsz, seq_len, 3, self.num_heads, self.head_dim) @@ -336,7 +337,7 @@ class ActionExpertCrossAttention(nn.Module): k: torch.Tensor, v: torch.Tensor, *, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: dropout_p = self.attn_dropout if self.training else 0.0 out = F.scaled_dot_product_attention( @@ -355,7 +356,7 @@ class ActionExpertCrossAttention(nn.Module): *, kv_k: torch.Tensor, kv_v: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: bsz, tgt_len, _ = x.shape q = self.q_proj(x).view(bsz, tgt_len, self.num_heads, self.head_dim) @@ -453,12 +454,12 @@ class ActionExpertBlock(nn.Module): x: torch.Tensor, conditioning: torch.Tensor, *, - cross_kv: Tuple[torch.Tensor, torch.Tensor], - self_attn_mask: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + cross_kv: tuple[torch.Tensor, torch.Tensor], + self_attn_mask: torch.Tensor | None = None, + attn_mask: torch.Tensor | None = None, is_causal: bool = False, - modulation: Optional[Tuple[torch.Tensor, ...]] = None, - rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + modulation: tuple[torch.Tensor, ...] | None = None, + rope_cache: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: if modulation is None: modulation = self.modulation(conditioning).chunk(9, dim=1) @@ -501,7 +502,7 @@ class ActionExpertFinalLayer(nn.Module): x: torch.Tensor, conditioning: torch.Tensor, *, - modulation: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + modulation: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: if modulation is None: modulation = self.modulation(conditioning).chunk(2, dim=1) @@ -565,8 +566,8 @@ class ActionExpert(nn.Module): self.context_norm = ( ActionExpertRMSNorm(config.hidden_size, eps=1e-6) if config.context_layer_norm else nn.Identity() ) - self._modulation_cache_key: Optional[Tuple[Any, ...]] = None - self._modulation_cache_value: Optional[Sequence[ActionExpertStepModulation]] = None + self._modulation_cache_key: tuple[Any, ...] | None = None + self._modulation_cache_value: Sequence[ActionExpertStepModulation] | None = None self.blocks = nn.ModuleList( [ ActionExpertBlock( @@ -638,8 +639,8 @@ class ActionExpert(nn.Module): def _prepare_kv_context( self, - encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]], - ) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]: + encoder_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]], + ) -> Sequence[tuple[torch.Tensor, torch.Tensor]]: if len(encoder_kv_states) != len(self.blocks): raise ValueError( f"Expected {len(self.blocks)} KV layers for per-layer conditioning, " @@ -657,10 +658,10 @@ class ActionExpert(nn.Module): @staticmethod def _build_cross_attention_mask( - encoder_attention_mask: Optional[torch.Tensor], + encoder_attention_mask: torch.Tensor | None, batch_size: int, dtype: torch.dtype, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: if encoder_attention_mask is None: return None mask = encoder_attention_mask[:, None, None, :].to(dtype=dtype) @@ -668,11 +669,11 @@ class ActionExpert(nn.Module): def _build_self_attention_mask( self, - action_attention_mask: Optional[torch.Tensor], + action_attention_mask: torch.Tensor | None, seq_len: int, device: torch.device, dtype: torch.dtype, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: mask = None if action_attention_mask is not None: valid = action_attention_mask.to(device=device, dtype=torch.bool) @@ -687,10 +688,10 @@ class ActionExpert(nn.Module): def prepare_context( self, *, - encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]], - encoder_attention_mask: Optional[torch.Tensor] = None, - action_attention_mask: Optional[torch.Tensor] = None, - state_embeddings: Optional[torch.Tensor] = None, + encoder_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]], + encoder_attention_mask: torch.Tensor | None = None, + action_attention_mask: torch.Tensor | None = None, + state_embeddings: torch.Tensor | None = None, batch_size: int, seq_len: int, device: torch.device, @@ -750,7 +751,7 @@ class ActionExpert(nn.Module): self, timesteps: Sequence[torch.Tensor], *, - cache_key: Optional[Tuple[Any, ...]] = None, + cache_key: tuple[Any, ...] | None = None, ) -> Sequence[ActionExpertStepModulation]: if self.training or cache_key is None: return self.prepare_modulation_cache(timesteps) @@ -767,7 +768,7 @@ class ActionExpert(nn.Module): timesteps: torch.Tensor, *, context: ActionExpertContext, - modulation: Optional[ActionExpertStepModulation] = None, + modulation: ActionExpertStepModulation | None = None, ) -> torch.Tensor: bsz, seq_len, _ = actions.shape if seq_len > self.config.max_action_horizon: @@ -776,7 +777,7 @@ class ActionExpert(nn.Module): ) if modulation is None: conditioning = self._time_conditioning(timesteps) - block_modulations: Sequence[Optional[Tuple[torch.Tensor, ...]]] = [None] * len(self.blocks) + block_modulations: Sequence[tuple[torch.Tensor, ...] | None] = [None] * len(self.blocks) final_modulation = None else: conditioning = modulation.conditioning @@ -810,10 +811,10 @@ class ActionExpert(nn.Module): actions: torch.Tensor, timesteps: torch.Tensor, *, - encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]], - encoder_attention_mask: Optional[torch.Tensor] = None, - action_attention_mask: Optional[torch.Tensor] = None, - state_embeddings: Optional[torch.Tensor] = None, + encoder_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]], + encoder_attention_mask: torch.Tensor | None = None, + action_attention_mask: torch.Tensor | None = None, + state_embeddings: torch.Tensor | None = None, ) -> torch.Tensor: bsz, seq_len, _ = actions.shape context = self.prepare_context( @@ -837,7 +838,7 @@ def _to_numpy(value: Any) -> np.ndarray: return np.asarray(value) -def _to_array(value: Any) -> Optional[np.ndarray]: +def _to_array(value: Any) -> np.ndarray | None: if value is None: return None if torch.is_tensor(value): @@ -848,7 +849,7 @@ def _to_array(value: Any) -> Optional[np.ndarray]: return np.asarray(value, dtype=np.float32) -def _to_mask(value: Any, fallback_like: Optional[np.ndarray]) -> Optional[np.ndarray]: +def _to_mask(value: Any, fallback_like: np.ndarray | None) -> np.ndarray | None: if value is None: return None mask = np.asarray(value, dtype=np.bool_) @@ -857,7 +858,7 @@ def _to_mask(value: Any, fallback_like: Optional[np.ndarray]) -> Optional[np.nda return mask -def _feature_dim_from_stats(stats: Optional[Mapping[str, Any]]) -> Optional[int]: +def _feature_dim_from_stats(stats: Mapping[str, Any] | None) -> int | None: if not isinstance(stats, Mapping): return None for key in ( @@ -888,14 +889,14 @@ class _FeatureNormalizer: self, *, mode: str, - mean: Optional[np.ndarray] = None, - std: Optional[np.ndarray] = None, - min_val: Optional[np.ndarray] = None, - max_val: Optional[np.ndarray] = None, - q_low: Optional[np.ndarray] = None, - q_high: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, - zero_mask: Optional[np.ndarray] = None, + mean: np.ndarray | None = None, + std: np.ndarray | None = None, + min_val: np.ndarray | None = None, + max_val: np.ndarray | None = None, + q_low: np.ndarray | None = None, + q_high: np.ndarray | None = None, + mask: np.ndarray | None = None, + zero_mask: np.ndarray | None = None, ): self.mode = mode self.mean = mean @@ -908,7 +909,7 @@ class _FeatureNormalizer: self.zero_mask = zero_mask @classmethod - def from_stats(cls, stats: Optional[Mapping[str, Any]], mode: str) -> Optional["_FeatureNormalizer"]: + def from_stats(cls, stats: Mapping[str, Any] | None, mode: str) -> Optional["_FeatureNormalizer"]: if stats is None: return None raw_mask = stats.get("mask") if isinstance(stats, Mapping) else None @@ -1019,7 +1020,7 @@ class _FeatureNormalizer: class _RobotStats: def __init__(self, payload: Mapping[str, Any]): self.norm_mode = str(payload.get("norm_mode", "min_max")) - self.metadata_by_tag: Dict[str, Dict[str, Any]] = { + self.metadata_by_tag: dict[str, dict[str, Any]] = { str(tag): dict(metadata or {}) for tag, metadata in dict(payload.get("metadata_by_tag") or {}).items() } @@ -1037,7 +1038,7 @@ class _RobotStats: self.norm_mode, ) - def validate_tag(self, norm_tag: Optional[str]) -> str: + def validate_tag(self, norm_tag: str | None) -> str: tag = str(norm_tag or "").strip() if not tag: raise ValueError("MolmoAct2 `predict_action` requires `norm_tag`.") @@ -1046,7 +1047,7 @@ class _RobotStats: raise ValueError(f"Unknown MolmoAct2 normalization tag {tag!r}. Allowed tags: {allowed}.") return tag - def get_metadata(self, norm_tag: Optional[str]) -> Dict[str, Any]: + def get_metadata(self, norm_tag: str | None) -> dict[str, Any]: if norm_tag is None: return {} return dict(self.metadata_by_tag.get(str(norm_tag), {}) or {}) @@ -1059,23 +1060,23 @@ class _RobotStats: normalizer = self.action_normalizers.get(str(norm_tag)) return action if normalizer is None else normalizer.unnormalize(action) - def get_action_dim(self, norm_tag: str) -> Optional[int]: + def get_action_dim(self, norm_tag: str) -> int | None: metadata = self.get_metadata(norm_tag) stats = metadata.get("action_stats") dim = _feature_dim_from_stats(stats) return dim - def get_state_dim(self, norm_tag: str) -> Optional[int]: + def get_state_dim(self, norm_tag: str) -> int | None: metadata = self.get_metadata(norm_tag) return _feature_dim_from_stats(metadata.get("state_stats")) - def get_action_horizon(self, norm_tag: str) -> Optional[int]: + def get_action_horizon(self, norm_tag: str) -> int | None: return self._get_positive_int(norm_tag, "action_horizon") - def get_n_action_steps(self, norm_tag: str) -> Optional[int]: + def get_n_action_steps(self, norm_tag: str) -> int | None: return self._get_positive_int(norm_tag, "n_action_steps") - def _get_positive_int(self, norm_tag: str, key: str) -> Optional[int]: + def _get_positive_int(self, norm_tag: str, key: str) -> int | None: value = self.get_metadata(norm_tag).get(key) if value is None: return None @@ -1102,7 +1103,7 @@ def _normalize_image_for_cache(image: Any) -> np.ndarray: return arr -def _extract_first_image(images: Any) -> Optional[np.ndarray]: +def _extract_first_image(images: Any) -> np.ndarray | None: if images is None: return None if isinstance(images, (list, tuple)): @@ -1170,11 +1171,11 @@ def _compute_depth_update_mask( def _build_depth_update_spans( update_mask: Sequence[bool], -) -> List[Tuple[int, int, bool]]: +) -> list[tuple[int, int, bool]]: flat_mask = np.asarray(update_mask, dtype=np.bool_).reshape(-1) if flat_mask.size == 0: return [] - spans: List[Tuple[int, int, bool]] = [] + spans: list[tuple[int, int, bool]] = [] start = 0 current_value = bool(flat_mask[0]) for idx in range(1, int(flat_mask.shape[0])): @@ -1214,7 +1215,7 @@ def _discretize_normalized_state(state: np.ndarray, num_state_tokens: int) -> np return np.clip(np.rint(scaled).astype(np.int64), 0, int(num_state_tokens) - 1) -def _build_discrete_state_string(state: Optional[np.ndarray], num_state_tokens: int) -> str: +def _build_discrete_state_string(state: np.ndarray | None, num_state_tokens: int) -> str: if state is None: return "" token_ids = _discretize_normalized_state(state, num_state_tokens).reshape(-1) @@ -1279,7 +1280,7 @@ def _build_robot_text( return f"{image_prefix}<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{trigger}" -def _flatten_generated_token_ids(token_ids: torch.Tensor) -> List[int]: +def _flatten_generated_token_ids(token_ids: torch.Tensor) -> list[int]: if token_ids.ndim == 3: return [int(x) for x in token_ids[0, 0].detach().cpu().tolist()] if token_ids.ndim == 2: @@ -1290,11 +1291,11 @@ def _flatten_generated_token_ids(token_ids: torch.Tensor) -> List[int]: def _extract_discrete_token_bins( - generated_ids: List[int], + generated_ids: list[int], start_token_id: int, end_token_id: int, - token_id_to_bin: Dict[int, int], -) -> List[int]: + token_id_to_bin: dict[int, int], +) -> list[int]: start_idx = None end_idx = None for idx, token_id in enumerate(generated_ids): @@ -1317,10 +1318,10 @@ def _extract_discrete_token_bins( @dataclass class MolmoAct2ActionOutput(ModelOutput): - actions: Optional[torch.FloatTensor] = None - generated_token_ids: Optional[torch.LongTensor] = None - depth_bins: Optional[torch.LongTensor] = None - depth_cache: Optional[Dict[str, Any]] = None + actions: torch.FloatTensor | None = None + generated_token_ids: torch.LongTensor | None = None + depth_bins: torch.LongTensor | None = None + depth_cache: dict[str, Any] | None = None @dataclass @@ -1328,10 +1329,10 @@ class _DepthPrefix: token_ids: torch.Tensor depth_bins: torch.Tensor full_input_ids: torch.Tensor - attention_mask: Optional[torch.Tensor] - encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]] + attention_mask: torch.Tensor | None + encoder_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]] next_output: Any - past_key_values: Optional[Cache] + past_key_values: Cache | None @dataclass @@ -1354,12 +1355,12 @@ class MolmoAct2CausalLMOutputWithPast(ModelOutput): image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[Cache] = None - hidden_states: Optional[tuple[torch.FloatTensor]] = None - attentions: Optional[tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[torch.FloatTensor] = None + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + image_hidden_states: torch.FloatTensor | None = None @dataclass @@ -1373,11 +1374,11 @@ class MolmoAct2ModelOutputWithPast(BaseModelOutputWithPast): image_hidden_states of the model produced by the vision backbone """ - last_hidden_state: Optional[torch.FloatTensor] = None - past_key_values: Optional[Cache] = None - hidden_states: Optional[tuple[torch.FloatTensor]] = None - attentions: Optional[tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + image_hidden_states: torch.FloatTensor | None = None class ViTMLP(nn.Module): @@ -1386,7 +1387,7 @@ class ViTMLP(nn.Module): dim: int, hidden_dim: int, hidden_act: str, - device: Union[str, torch.device] = None, + device: str | torch.device = None, ): super().__init__() self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device) @@ -1405,11 +1406,11 @@ class ViTMultiHeadDotProductAttention(nn.Module): num_key_value_heads: int, head_dim: int, use_bias: bool = True, - input_dim: Optional[int] = None, + input_dim: int | None = None, float32_attention: bool = True, attention_dropout: float = 0.0, residual_dropout: float = 0.0, - device: Union[str, torch.device] = None, + device: str | torch.device = None, attn_implementation: str = "eager", ): super().__init__() @@ -1465,10 +1466,9 @@ class ViTMultiHeadDotProductAttention(nn.Module): def forward( self, inputs_q: torch.Tensor, - inputs_kv: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + inputs_kv: torch.Tensor | None = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: - if inputs_kv is not None: inputs_k = inputs_kv inputs_v = inputs_kv @@ -1558,7 +1558,7 @@ class ViTMultiHeadDotProductAttention(nn.Module): class MolmoAct2VisionBlock(nn.Module): - def __init__(self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None): + def __init__(self, config: MolmoAct2VitConfig, device: str | torch.device = None): super().__init__() self.attention = ViTMultiHeadDotProductAttention( hidden_size=config.hidden_size, @@ -1587,9 +1587,9 @@ class MolmoAct2VisionBlock(nn.Module): class MolmoAct2VisionBlockCollection(nn.Module): - def __init__(self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None): + def __init__(self, config: MolmoAct2VitConfig, device: str | torch.device = None): super().__init__() - self.conifg = config + self.config = config self.resblocks = nn.ModuleList( [MolmoAct2VisionBlock(config, device) for _ in range(config.num_hidden_layers)] ) @@ -1603,7 +1603,7 @@ class MolmoAct2VisionBlockCollection(nn.Module): class MolmoAct2VisionTransformer(nn.Module): - def __init__(self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None): + def __init__(self, config: MolmoAct2VitConfig, device: str | torch.device = None): super().__init__() self.config = config @@ -1638,7 +1638,7 @@ class MolmoAct2VisionTransformer(nn.Module): (patch_num_0, patch_num_1) = patch_num if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: - # Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + # Derived from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py # antialias: default True in jax.image.resize pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) pos_emb = F.interpolate( @@ -1679,7 +1679,7 @@ class ImageProjectorMLP(nn.Module): hidden_dim: int, output_dim: int, hidden_act: str, - device: Union[str, torch.device] = None, + device: str | torch.device = None, ): super().__init__() self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device) @@ -1781,8 +1781,7 @@ class MolmoAct2VisionBackbone(nn.Module): self, images: torch.Tensor, pooled_patches_idx: torch.Tensor, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - + ) -> tuple[torch.Tensor, torch.Tensor | None]: # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) batch_size, num_image = images.shape[:2] images = images.to(device=self.device) @@ -1878,8 +1877,8 @@ class MolmoAct2RotaryEmbedding(nn.Module): def __init__( self, config: MolmoAct2TextConfig, - device: Union[str, torch.device] = None, - rope_type: Optional[str] = None, + device: str | torch.device = None, + rope_type: str | None = None, ): super().__init__() if rope_type is not None: @@ -1906,7 +1905,7 @@ class MolmoAct2RotaryEmbedding(nn.Module): @staticmethod def _default_rope_init( - config: MolmoAct2TextConfig, device: Union[str, torch.device] = None, **_ + config: MolmoAct2TextConfig, device: str | torch.device = None, **_ ) -> tuple[torch.Tensor, float]: inv_freq = 1.0 / ( config.rope_theta @@ -1914,7 +1913,7 @@ class MolmoAct2RotaryEmbedding(nn.Module): ) return inv_freq, 1.0 - def _target_cache_seq_len(self, x: torch.Tensor, position_ids: Optional[torch.Tensor]) -> int: + def _target_cache_seq_len(self, x: torch.Tensor, position_ids: torch.Tensor | None) -> int: if self.config.max_position_embeddings: return int(self.config.max_position_embeddings) if position_ids is not None: @@ -1967,8 +1966,8 @@ class MolmoAct2RotaryEmbedding(nn.Module): def prepare_rope_cache( self, *, - device: Union[str, torch.device], - max_seq_len: Optional[int] = None, + device: str | torch.device, + max_seq_len: int | None = None, ) -> None: if self.rope_type != "default": return @@ -1984,7 +1983,7 @@ class MolmoAct2RotaryEmbedding(nn.Module): def _select_rope_cache( self, x: torch.Tensor, - position_ids: Optional[torch.Tensor], + position_ids: torch.Tensor | None, seq_len: int, ) -> tuple[torch.Tensor, torch.Tensor]: pos_sin = self._pos_sin_cache[:, :, :seq_len, :] @@ -2012,7 +2011,7 @@ class MolmoAct2RMSNorm(nn.Module): self, size: int, eps: float = 1e-6, - device: Union[str, torch.device] = None, + device: str | torch.device = None, ): super().__init__() self.weight = nn.Parameter(torch.ones(size, device=device)) @@ -2050,11 +2049,11 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: torch.Tensor | None, scaling: float, dropout: float = 0.0, **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor | None]: key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -2097,9 +2096,9 @@ class MolmoAct2Attention(nn.Module): ) # Layer norms. - self.k_norm: Optional[MolmoAct2RMSNorm] = None - self.q_norm: Optional[MolmoAct2RMSNorm] = None - self.qk_norm_type: Optional[str] = None + self.k_norm: MolmoAct2RMSNorm | None = None + self.q_norm: MolmoAct2RMSNorm | None = None + self.qk_norm_type: str | None = None if config.use_qk_norm: k_norm_size = ( config.head_dim @@ -2127,11 +2126,11 @@ class MolmoAct2Attention(nn.Module): self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -2212,7 +2211,7 @@ class LanguageModelMLP(nn.Module): input_dim: int, intermediate_size: int, hidden_act: str, - device: Union[str, torch.device] = None, + device: str | torch.device = None, ): super().__init__() self.ff_proj = nn.Linear(input_dim, intermediate_size * 2, bias=False, device=device) @@ -2231,8 +2230,8 @@ class MolmoAct2DecoderLayer(GradientCheckpointingLayer): def __init__( self, config: MolmoAct2TextConfig, - layer_idx: Optional[int] = None, - device: Union[str, torch.device] = None, + layer_idx: int | None = None, + device: str | torch.device = None, ): super().__init__() self.config = config @@ -2252,14 +2251,14 @@ class MolmoAct2DecoderLayer(GradientCheckpointingLayer): self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) residual = hidden_states @@ -2305,14 +2304,14 @@ class MolmoAct2PostNormDecoderLayer(MolmoAct2DecoderLayer): self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, **kwargs, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: collect_layer_kv_states = bool(kwargs.pop("collect_layer_kv_states", False)) residual = hidden_states @@ -2359,7 +2358,7 @@ class MolmoAct2Embedding(nn.Module): num_embeddings: int, num_new_embeddings: int, features: int, - device: Union[str, torch.device] = None, + device: str | torch.device = None, ): super().__init__() self.embedding = nn.Parameter( @@ -2453,8 +2452,8 @@ class MolmoAct2TextModel(MolmoAct2PreTrainedModel): def prepare_rope_cache( self, *, - device: Union[str, torch.device], - max_seq_len: Optional[int] = None, + device: str | torch.device, + max_seq_len: int | None = None, ) -> None: if self.config.rope_scaling_layers is not None: for rotary_emb in self.rotary_embs.values(): @@ -2471,15 +2470,15 @@ class MolmoAct2TextModel(MolmoAct2PreTrainedModel): @can_return_tuple def forward( self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: output_attentions = ( @@ -2607,8 +2606,8 @@ class MolmoAct2TextModel(MolmoAct2PreTrainedModel): # Adapted from transformers.models.gemma3.modeling_gemma3 def token_type_ids_mask_function( - token_type_ids: Optional[torch.Tensor] = None, -) -> Optional[Callable]: + token_type_ids: torch.Tensor | None = None, +) -> Callable | None: """ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, not start and end indices. @@ -2643,7 +2642,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): def __init__(self, config: MolmoAct2Config): super().__init__(config) self.transformer: MolmoAct2TextModel = MolmoAct2TextModel(config.text_config) - self.vision_backbone: Optional[MolmoAct2VisionBackbone] = None + self.vision_backbone: MolmoAct2VisionBackbone | None = None if config.vit_config is not None and config.adapter_config is not None: self.vision_backbone = MolmoAct2VisionBackbone(config.vit_config, config.adapter_config) llm_kv_dim = config.text_config.num_key_value_heads * config.text_config.head_dim @@ -2667,7 +2666,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): else: self.action_expert_depth_gate = None self._depth_gate_token_ids = self._resolve_depth_gate_token_ids() - self.action_cuda_graph_manager: Optional[ActionCudaGraphManager] = None + self.action_cuda_graph_manager: ActionCudaGraphManager | None = None # Initialize weights and apply final processing self.post_init() @@ -2700,7 +2699,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): nn.init.zeros_(gate.weight) nn.init.constant_(gate.bias, float(self.config.action_expert_depth_gate_init_bias)) - def _resolve_depth_gate_token_ids(self) -> Tuple[int, ...]: + def _resolve_depth_gate_token_ids(self) -> tuple[int, ...]: if not self.config.action_expert_depth_gate: return () token_ids = [] @@ -2740,7 +2739,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): bsz, seq_len, n_heads, head_dim = cache.shape return cache.reshape(bsz, seq_len, n_heads * head_dim) - def _extract_kv_states(self, past_key_values: Cache) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]: + def _extract_kv_states(self, past_key_values: Cache) -> Sequence[tuple[torch.Tensor, torch.Tensor]]: if past_key_values is None: raise RuntimeError("Action generation requires past_key_values from the VLM forward pass.") seq_len = _cache_seq_len_int(past_key_values) @@ -2762,8 +2761,8 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): def _mask_discrete_output_span( row_ids: torch.Tensor, row_mask: torch.Tensor, - start_id: Optional[int], - end_id: Optional[int], + start_id: int | None, + end_id: int | None, ) -> None: if start_id is None or end_id is None: return @@ -2784,9 +2783,9 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): def _get_encoder_attention_mask( self, - input_ids: Optional[torch.Tensor], - attention_mask: Optional[torch.Tensor], - ) -> Optional[torch.Tensor]: + input_ids: torch.Tensor | None, + attention_mask: torch.Tensor | None, + ) -> torch.Tensor | None: if attention_mask is not None: mask = attention_mask.to(dtype=torch.bool).clone() elif input_ids is not None: @@ -2809,9 +2808,9 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): def _get_depth_token_mask( self, - input_ids: Optional[torch.Tensor], - encoder_attention_mask: Optional[torch.Tensor], - ) -> Optional[torch.Tensor]: + input_ids: torch.Tensor | None, + encoder_attention_mask: torch.Tensor | None, + ) -> torch.Tensor | None: if not self.config.action_expert_depth_gate or input_ids is None or not self._depth_gate_token_ids: return None depth_token_ids = torch.as_tensor( @@ -2830,7 +2829,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): *, source: torch.Tensor, depth_mask: torch.Tensor, - encoder_attention_mask: Optional[torch.Tensor], + encoder_attention_mask: torch.Tensor | None, ) -> torch.Tensor: if source.ndim == 4: source = source.reshape(source.shape[0], source.shape[1], -1) @@ -2852,10 +2851,10 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): def _depth_gate_from_condition( self, *, - input_ids: Optional[torch.Tensor], - encoder_attention_mask: Optional[torch.Tensor], - layer_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]], - ) -> Tuple[Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], Optional[torch.Tensor]]: + input_ids: torch.Tensor | None, + encoder_attention_mask: torch.Tensor | None, + layer_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]] | None, + ) -> tuple[torch.Tensor | Sequence[torch.Tensor] | None, torch.Tensor | None]: gate_head = self.action_expert_depth_gate if gate_head is None: return None, None @@ -2888,7 +2887,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): @staticmethod def _depth_gate_for_layer( - gate: Union[torch.Tensor, Sequence[torch.Tensor]], + gate: torch.Tensor | Sequence[torch.Tensor], layer_idx: int, *, num_layers: int, @@ -2901,10 +2900,10 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): def _apply_depth_gate_to_layer_kv_states( self, - layer_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]], - depth_mask: Optional[torch.Tensor], - gate: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], - ) -> Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]]: + layer_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]] | None, + depth_mask: torch.Tensor | None, + gate: torch.Tensor | Sequence[torch.Tensor] | None, + ) -> Sequence[tuple[torch.Tensor, torch.Tensor]] | None: if layer_kv_states is None or depth_mask is None or gate is None: return layer_kv_states gated_kv = [] @@ -2924,8 +2923,8 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): @staticmethod def _action_dim_valid_mask( target: torch.Tensor, - action_dim_is_pad: Optional[torch.Tensor], - ) -> Optional[torch.Tensor]: + action_dim_is_pad: torch.Tensor | None, + ) -> torch.Tensor | None: if action_dim_is_pad is None: return None mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool) @@ -2950,7 +2949,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): cls, tensor: torch.Tensor, *, - action_dim_is_pad: Optional[torch.Tensor], + action_dim_is_pad: torch.Tensor | None, enabled: bool, ) -> torch.Tensor: if not enabled: @@ -2986,7 +2985,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): ) return trajectory - def _resolve_action_horizon(self, action_horizon: Optional[int] = None) -> int: + def _resolve_action_horizon(self, action_horizon: int | None = None) -> int: max_action_horizon = int(self.config.max_action_horizon or 1) resolved = max_action_horizon if action_horizon is None else int(action_horizon) if resolved < 1: @@ -3002,22 +3001,22 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): self, *, input_ids: torch.LongTensor, - pixel_values: Optional[torch.Tensor] = None, - image_token_pooling: Optional[torch.Tensor] = None, - image_grids: Optional[torch.Tensor] = None, - image_num_crops: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.Tensor] = None, - video_token_pooling: Optional[torch.Tensor] = None, - video_grids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - states: Optional[torch.Tensor] = None, - action_dim_is_pad: Optional[torch.Tensor] = None, - action_horizon: Optional[int] = None, - num_steps: Optional[int] = None, - generator: Optional[torch.Generator] = None, - encoder_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, + pixel_values: torch.Tensor | None = None, + image_token_pooling: torch.Tensor | None = None, + image_grids: torch.Tensor | None = None, + image_num_crops: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_token_pooling: torch.Tensor | None = None, + video_grids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.LongTensor | None = None, + states: torch.Tensor | None = None, + action_dim_is_pad: torch.Tensor | None = None, + action_horizon: int | None = None, + num_steps: int | None = None, + generator: torch.Generator | None = None, + encoder_kv_states: Sequence[tuple[torch.Tensor, torch.Tensor]] | None = None, + encoder_attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: action_expert = self._require_action_expert() if encoder_kv_states is None: @@ -3267,7 +3266,6 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): video_token_pooling: torch.Tensor, video_grids: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - # 1) Count the number of videos in each example if self.config.use_frame_special_tokens: end_token_id = self.config.frame_end_token_id @@ -3370,15 +3368,15 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): def merge_visual_inputs( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.Tensor] = None, - image_token_pooling: Optional[torch.Tensor] = None, - image_grids: Optional[torch.Tensor] = None, - image_num_crops: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.Tensor] = None, - video_token_pooling: Optional[torch.Tensor] = None, - video_grids: Optional[torch.Tensor] = None, - ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + input_ids: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_token_pooling: torch.Tensor | None = None, + image_grids: torch.Tensor | None = None, + image_num_crops: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_token_pooling: torch.Tensor | None = None, + video_grids: torch.Tensor | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: if pixel_values is not None and pixel_values_videos is not None: raise ValueError("pixel_values and pixel_values_videos are provided at the same time") elif pixel_values is not None: @@ -3405,16 +3403,15 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): def build_input_embeddings( self, input_ids: torch.LongTensor, - images: Optional[torch.FloatTensor] = None, # image inputs - token_pooling: Optional[torch.LongTensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - + images: torch.FloatTensor | None = None, # image inputs + token_pooling: torch.LongTensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: # Get embeddings of input. # shape: (batch_size, seq_len, d_model) input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) x = self.transformer.wte(input_ids) - image_features: Optional[torch.FloatTensor] = None + image_features: torch.FloatTensor | None = None if images is not None: image_features = self.vision_backbone(images, token_pooling).to(x.device) is_image_patch = input_ids.reshape(-1) == self.config.image_patch_id @@ -3435,9 +3432,9 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): self, *, inputs_embeds: torch.Tensor, - attention_mask: Optional[torch.Tensor], - token_type_ids: Optional[torch.Tensor], - past_key_values: Optional[Cache], + attention_mask: torch.Tensor | None, + token_type_ids: torch.Tensor | None, + past_key_values: Cache | None, ) -> torch.Tensor: if attention_mask is not None and attention_mask.ndim == 4: return attention_mask.to(device=inputs_embeds.device) @@ -3492,26 +3489,25 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel): @can_return_tuple def forward( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - image_token_pooling: Optional[torch.Tensor] = None, - image_grids: Optional[torch.Tensor] = None, - image_num_crops: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.Tensor] = None, - video_token_pooling: Optional[torch.Tensor] = None, - video_grids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[Cache] = None, - token_type_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + image_token_pooling: torch.Tensor | None = None, + image_grids: torch.Tensor | None = None, + image_num_crops: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_token_pooling: torch.Tensor | None = None, + video_grids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values: Cache | None = None, + token_type_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> Union[tuple, MolmoAct2ModelOutputWithPast]: - + ) -> tuple | MolmoAct2ModelOutputWithPast: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) @@ -3613,7 +3609,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi def get_decoder(self): return self.model.get_decoder() - # Make modules available throught conditional class for BC + # Make modules available through conditional class for BC @property def language_model(self) -> torch.nn.Module: return self.model.transformer @@ -3644,21 +3640,21 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi f"MolmoAct2 normalization stats file is missing: {stats_path}. " "Converted checkpoints must include norm_stats.json." ) from exc - with open(stats_path, "r", encoding="utf-8") as f: + with open(stats_path, encoding="utf-8") as f: payload = json.load(f) stats = _RobotStats(payload) self._molmoact2_robot_stats = stats return stats @staticmethod - def _move_inputs_to_device(inputs: Mapping[str, Any], device: torch.device) -> Dict[str, Any]: + def _move_inputs_to_device(inputs: Mapping[str, Any], device: torch.device) -> dict[str, Any]: out = {} for key, value in inputs.items(): out[key] = value.to(device) if torch.is_tensor(value) else value return out @staticmethod - def _drop_trivial_attention_mask(inputs: Mapping[str, Any]) -> Dict[str, Any]: + def _drop_trivial_attention_mask(inputs: Mapping[str, Any]) -> dict[str, Any]: out = dict(inputs) attention_mask = out.get("attention_mask") if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()): @@ -3683,7 +3679,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi max_action_dim: int, batch_size: int, device: torch.device, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: if int(action_dim) > int(max_action_dim): raise ValueError( f"Requested action_dim {int(action_dim)} exceeds checkpoint max_action_dim {int(max_action_dim)}." @@ -3704,7 +3700,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi @staticmethod def _slice_action_chunk( - actions: torch.Tensor, n_obs_steps: int, n_action_steps: Optional[int] + actions: torch.Tensor, n_obs_steps: int, n_action_steps: int | None ) -> torch.Tensor: if n_action_steps is None: return actions @@ -3714,13 +3710,13 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi raise ValueError(f"Requested actions up to {end} but model produced horizon {actions.shape[1]}.") return actions[:, start:end] - def _depth_token_id_to_bin(self) -> Dict[int, int]: + def _depth_token_id_to_bin(self) -> dict[int, int]: if self.config.depth_token_start_id is None or int(self.config.num_depth_tokens or 0) <= 0: return {} start = int(self.config.depth_token_start_id) return {start + idx: idx for idx in range(int(self.config.num_depth_tokens))} - def _action_token_id_to_bin(self) -> Dict[int, int]: + def _action_token_id_to_bin(self) -> dict[int, int]: if self.config.action_token_start_id is None or int(self.config.num_action_tokens or 0) <= 0: return {} start = int(self.config.action_token_start_id) @@ -3758,9 +3754,9 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi self, token_ids: torch.Tensor, *, - past_key_values: Optional[Cache], - attention_mask: Optional[torch.Tensor], - ) -> Tuple[MolmoAct2CausalLMOutputWithPast, Optional[torch.Tensor]]: + past_key_values: Cache | None, + attention_mask: torch.Tensor | None, + ) -> tuple[MolmoAct2CausalLMOutputWithPast, torch.Tensor | None]: if token_ids.ndim == 1: next_input_ids = token_ids.unsqueeze(1) elif token_ids.ndim == 2: @@ -3842,7 +3838,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi *, past_key_values: Cache, attention_bias: torch.Tensor, - ) -> Tuple[torch.Tensor, Cache]: + ) -> tuple[torch.Tensor, Cache]: if token_ids.ndim == 1: next_input_ids = token_ids.unsqueeze(1) elif token_ids.ndim == 2: @@ -3882,7 +3878,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi *, past_key_values: Cache, attention_bias: torch.Tensor, - ) -> Tuple[torch.Tensor, Cache]: + ) -> tuple[torch.Tensor, Cache]: return self._run_ar_decode_step( token_ids, past_key_values=past_key_values, @@ -3920,13 +3916,13 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi self, initial_output: MolmoAct2CausalLMOutputWithPast, *, - past_key_values: Optional[Cache], - attention_mask: Optional[torch.Tensor], + past_key_values: Cache | None, + attention_mask: torch.Tensor | None, end_token_id: int, max_steps: int, - attention_bias: Optional[torch.Tensor] = None, + attention_bias: torch.Tensor | None = None, ) -> torch.Tensor: - generated_tokens: List[torch.Tensor] = [] + generated_tokens: list[torch.Tensor] = [] current_output = initial_output current_past_key_values = past_key_values current_attention_mask = attention_mask @@ -3966,8 +3962,8 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi self, inputs: Mapping[str, Any], *, - latest_first_image: Optional[np.ndarray], - depth_cache: Optional[Mapping[str, Any]], + latest_first_image: np.ndarray | None, + depth_cache: Mapping[str, Any] | None, enable_adaptive_depth: bool, ) -> _DepthPrefix: if self.config.depth_start_token_id is None or self.config.depth_end_token_id is None: @@ -3982,7 +3978,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi current_output = output current_past_key_values = output.past_key_values current_attention_mask = inputs.get("attention_mask") - generated_tokens: List[torch.Tensor] = [] + generated_tokens: list[torch.Tensor] = [] if not enable_adaptive_depth: hit_depth_end = False @@ -4213,18 +4209,18 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi task: str, state: Any, norm_tag: str, - inference_action_mode: Optional[str] = None, + inference_action_mode: str | None = None, enable_depth_reasoning: bool = False, enable_adaptive_depth: bool = True, - depth_cache: Optional[Mapping[str, Any]] = None, + depth_cache: Mapping[str, Any] | None = None, action_tokenizer: Any = None, - num_steps: Optional[int] = None, - n_action_steps: Optional[int] = None, - generator: Optional[torch.Generator] = None, + num_steps: int | None = None, + n_action_steps: int | None = None, + generator: torch.Generator | None = None, normalize_language: bool = True, enable_cuda_graph: bool = True, return_dict: bool = True, - ) -> Union[MolmoAct2ActionOutput, torch.Tensor]: + ) -> MolmoAct2ActionOutput | torch.Tensor: if state is None: raise ValueError("MolmoAct2 `predict_action` requires `state` for discrete state prompting.") if inference_action_mode is None: @@ -4438,26 +4434,26 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi def forward( self, input_ids: torch.LongTensor = None, - pixel_values: Optional[torch.Tensor] = None, - image_token_pooling: Optional[torch.Tensor] = None, - image_grids: Optional[torch.Tensor] = None, - image_num_crops: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.Tensor] = None, - video_token_pooling: Optional[torch.Tensor] = None, - video_grids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, + pixel_values: torch.Tensor | None = None, + image_token_pooling: torch.Tensor | None = None, + image_grids: torch.Tensor | None = None, + image_num_crops: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_token_pooling: torch.Tensor | None = None, + video_grids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + token_type_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> Union[tuple, MolmoAct2CausalLMOutputWithPast]: + ) -> tuple | MolmoAct2CausalLMOutputWithPast: r""" ```python >>> from PIL import Image @@ -4524,22 +4520,21 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - image_token_pooling: Optional[torch.Tensor] = None, - image_grids: Optional[torch.Tensor] = None, - image_num_crops: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.Tensor] = None, - video_token_pooling: Optional[torch.Tensor] = None, - video_grids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Optional[Union[int, torch.Tensor]] = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + image_token_pooling: torch.Tensor | None = None, + image_grids: torch.Tensor | None = None, + image_num_crops: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + video_token_pooling: torch.Tensor | None = None, + video_grids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor | None = None, **kwargs, ): - model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -4570,11 +4565,11 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi def create_masks_for_generate( config: PretrainedConfig, input_embeds: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: torch.Tensor | None, cache_position: torch.Tensor, - past_key_values: Optional[Cache], - position_ids: Optional[torch.Tensor], - token_type_ids: Optional[torch.Tensor] = None, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + token_type_ids: torch.Tensor | None = None, **kwargs, ) -> dict: # Prepare mask arguments diff --git a/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py index e01284bc8..7b8775faa 100644 --- a/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py @@ -102,12 +102,12 @@ class MolmoAct2Processor(ProcessorMixin): image_processor: MolmoAct2ImageProcessor = None, video_processor: MolmoAct2VideoProcessor = None, tokenizer: AutoTokenizer = None, - chat_template: Optional[str] = None, - image_use_col_tokens: Optional[bool] = True, - use_single_crop_col_tokens: Optional[bool] = None, - use_single_crop_start_token: Optional[bool] = True, - video_use_col_tokens: Optional[bool] = False, - use_frame_special_tokens: Optional[bool] = True, + chat_template: str | None = None, + image_use_col_tokens: bool | None = True, + use_single_crop_col_tokens: bool | None = None, + use_single_crop_start_token: bool | None = True, + video_use_col_tokens: bool | None = False, + use_frame_special_tokens: bool | None = True, **kwargs, ) -> None: super().__init__( @@ -272,7 +272,7 @@ class MolmoAct2Processor(ProcessorMixin): def __call__( self, - text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, images: ImageInput = None, videos: VideoInput = None, **kwargs: Unpack[MolmoAct2ProcessorKwargs], diff --git a/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py b/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py index 01763370a..644d5a691 100644 --- a/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py +++ b/src/lerobot/policies/molmoact2/hf_model/video_processing_molmoact2.py @@ -24,7 +24,8 @@ import warnings from contextlib import redirect_stdout from io import BytesIO from urllib.parse import urlparse -from typing import Optional, Union, Callable +from typing import Optional, Union +from collections.abc import Callable import numpy as np import requests @@ -224,9 +225,9 @@ def image_to_patches_and_grids( def get_candidate_target_fps( - video_fps: Union[int, float], - sampling_fps: Union[int, float], - max_fps: Union[int, float] = MAX_VIDEO_FPS, + video_fps: int | float, + sampling_fps: int | float, + max_fps: int | float = MAX_VIDEO_FPS, ) -> list[float]: """ Return the subset of `video_fps` factors that remain multiples of `sampling_fps`. @@ -468,7 +469,7 @@ VIDEO_DECODERS = { def load_video( video: VideoInput, backend: str = "decord", - sample_timestamps_fn: Optional[Callable] = None, + sample_timestamps_fn: Callable | None = None, **kwargs, ): """ @@ -502,7 +503,7 @@ def load_video( bytes_obj = buffer.getvalue() file_obj = BytesIO(bytes_obj) elif video.startswith("http://") or video.startswith("https://"): - file_obj = BytesIO(requests.get(video).content) + file_obj = BytesIO(requests.get(video, timeout=10).content) elif os.path.isfile(video): file_obj = video else: @@ -579,11 +580,11 @@ def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False): - patch_size: Optional[int] - pooling_size: Optional[list[int]] - frame_sample_mode: Optional[str] - max_fps: Optional[int] - sampling_fps: Optional[int] + patch_size: int | None + pooling_size: list[int] | None + frame_sample_mode: str | None + max_fps: int | None + sampling_fps: int | None class MolmoAct2VideoProcessor(BaseVideoProcessor): @@ -613,7 +614,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): def _further_process_kwargs( self, - size: Optional[SizeDict] = None, + size: SizeDict | None = None, **kwargs, ) -> dict: """ @@ -630,8 +631,8 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): metadata: VideoMetadata, frame_sample_mode: str, num_frames: int, - max_fps: Optional[int] = None, - sampling_fps: Optional[int] = None, + max_fps: int | None = None, + sampling_fps: int | None = None, **kwargs, ) -> np.ndarray: """ @@ -683,10 +684,10 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): def sample_frames( self, metadata: VideoMetadata, - frame_sample_mode: Optional[str] = None, - num_frames: Optional[int] = None, - max_fps: Optional[int] = None, - sampling_fps: Optional[int] = None, + frame_sample_mode: str | None = None, + num_frames: int | None = None, + max_fps: int | None = None, + sampling_fps: int | None = None, **kwargs, ) -> np.ndarray: """ @@ -761,9 +762,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): else: raise NotImplementedError(frame_sample_mode) - def fetch_videos( - self, video_url_or_urls: Union[str, list[str], list[list[str]]], sample_timestamps_fn=None - ): + def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None): """ Convert a single or a list of urls into the corresponding `np.array` objects. @@ -805,10 +804,10 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): def _decode_and_sample_videos( self, videos: VideoInput, - video_metadata: Union[VideoMetadata, dict], - do_sample_frames: Optional[bool] = None, - sample_indices_fn: Optional[Callable] = None, - sample_timestamps_fn: Optional[Callable] = None, + video_metadata: VideoMetadata | dict, + do_sample_frames: bool | None = None, + sample_indices_fn: Callable | None = None, + sample_timestamps_fn: Callable | None = None, ): """ Decode input videos and sample frames if needed. @@ -890,14 +889,14 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor): def _preprocess( self, videos: list[np.ndarray], - size: Optional[SizeDict] = None, - resample: Optional[PILImageResampling] = None, - image_mean: Optional[Union[float, list[float]]] = None, - image_std: Optional[Union[float, list[float]]] = None, - do_convert_rgb: Optional[bool] = None, - patch_size: Optional[int] = None, - pooling_size: Optional[list[int]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + size: SizeDict | None = None, + resample: PILImageResampling | None = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + do_convert_rgb: bool | None = None, + patch_size: int | None = None, + pooling_size: list[int] | None = None, + return_tensors: str | TensorType | None = None, **kwargs, ) -> BatchFeature: """