diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index 458028c3a..bebba4a27 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -34,6 +34,7 @@ lerobot-train \ ``` """ +import builtins import glob import math import os @@ -70,7 +71,8 @@ from transformers.utils import is_torchdynamo_compiling, logging from transformers import AutoProcessor, BatchFeature from qwen_vl_utils.vision_process import smart_resize -from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.utils import populate_queues from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE @@ -91,8 +93,11 @@ from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import ( Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLDecoderLayer_with_MoE, Qwen2_5_VLACausalLMOutputWithPast, + Qwen2_5_VLMoEModel, ) +logger = logging.get_logger(__name__) + # Add wall-x repo to path if available WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x") if WALL_X_PATH.exists(): @@ -257,461 +262,6 @@ class ActionHead(nn.Module): return self.propri_proj(proprioception) -class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): - """Qwen2.5-VL model with Mixture of Experts (MoE) architecture. - - This model extends the base Qwen2.5-VL model by incorporating MoE layers - for improved scalability and specialization across different token types. - """ - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: str, - num_experts: Optional[int] = None, - *args, - **kwargs, - ): - """Load a pretrained model with optional MoE configuration. 
- - Args: - pretrained_model_name_or_path: Path or name of the pretrained model - num_experts: Number of experts for MoE layers (if not in config) - *args: Additional arguments passed to parent class - **kwargs: Additional keyword arguments passed to parent class - - Returns: - Initialized model instance with MoE configuration - """ - config = kwargs.get("config", None) - if config is None: - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - - # Override number of experts if specified - if num_experts is not None: - config.num_experts = num_experts - - kwargs["config"] = config - return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - - def __init__(self, config: Qwen2_5_VLConfig): - """Initialize the Qwen2.5-VL MoE model. - - Args: - config: Model configuration containing architecture parameters - """ - super().__init__(config) - - # Basic model parameters - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - # Model components - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) - - # Decoder layers with MoE support - self.layers = nn.ModuleList( - [ - Qwen2_5_VLDecoderLayer_with_MoE(config, layer_idx, config.num_experts) - for layer_idx in range(config.num_hidden_layers) - ] - ) - - # Model configuration - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Embedding: - """Get the input embedding layer. - - Returns: - The token embedding layer - """ - return self.embed_tokens - - def set_input_embeddings(self, value: nn.Embedding) -> None: - """Set the input embedding layer. 
- - Args: - value: New embedding layer to use - """ - self.embed_tokens = value - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - moe_token_types: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: - # Set default output options - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # Validate inputs - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) - - if moe_token_types is None: - raise ValueError("moe_token_types must be provided for MoE routing") - - # Handle gradient checkpointing compatibility - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # Initialize cache if needed - if use_cache and past_key_values is None and not torch.jit.is_tracing(): - past_key_values = DynamicCache() - - # Get input embeddings - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # Set up cache position - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - # Set up position IDs (hardcoded 3 dimensions for temporal, height, width) - if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand( - 3, inputs_embeds.shape[0], -1 - ) - elif position_ids.dim() == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - - # Create causal attention mask - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - output_attentions, - moe_token_types, - ) - - hidden_states = inputs_embeds - - # Create position embeddings to be shared across decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # Initialize output collections - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - # Process through decoder layers - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - # Use gradient checkpointing during training - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - moe_token_types, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - # Regular forward pass - layer_outputs = decoder_layer( - hidden_states, - 
attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - token_types=moe_token_types, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - hidden_states = layer_outputs[0] - - # Update cache if using it - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - # Collect attention weights if requested - if output_attentions: - all_self_attns += (layer_outputs[1],) - - # Apply final layer normalization - hidden_states = self.norm(hidden_states) - - # Add final hidden states if collecting all states - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - # Return outputs in requested format - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - moe_token_types: Optional[torch.LongTensor] = None, - ): - """Update causal attention mask with support for bidirectional attention for specific token types. 
- - This method creates and modifies attention masks to support different attention patterns: - - Standard causal (unidirectional) attention for most tokens - - Bidirectional attention for specific token types (e.g., MoE routing tokens) - - Args: - attention_mask: Input attention mask to avoid attending to padding tokens - input_tensor: Input embeddings tensor for shape and device information - cache_position: Position indices for caching mechanisms - past_key_values: Cached key-value pairs from previous forward passes - output_attentions: Whether attention weights will be returned - moe_token_types: Optional tensor indicating token types for MoE routing - (type 1 tokens will use bidirectional attention) - - Returns: - Updated causal attention mask, or None if using Flash Attention 2 - """ - # Flash Attention 2 handles masking internally - if self.config._attn_implementation == "flash_attention_2": - return None - - # Calculate sequence lengths for cache management - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # For SDPA (Scaled Dot Product Attention), use `is_causal` argument when possible - # instead of explicit attention mask to enable Flash Attention 2 dispatch - # Note: This optimization is not compatible with static cache - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - # Check if we can ignore the causal mask and rely on SDPA's internal handling - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - # Extract tensor properties for mask creation - dtype, device = 
input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - - # Determine target length based on cache type - if using_sliding_window_cache or using_static_cache: - # Use maximum cache shape for sliding window or static caches - target_length = past_key_values.get_max_cache_shape() - else: - # For dynamic cache or no cache, calculate based on attention mask or sequence length - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # Generate 4D causal attention mask from 2D input mask if provided - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - # Modify mask to support bidirectional attention for specific token types - if moe_token_types is not None: - # Identify positions of type 1 tokens (MoE routing tokens) - type1_tokens = ( - (moe_token_types == 1).unsqueeze(1).unsqueeze(2) - ) # Shape: [B, 1, 1, S] - - # Create bidirectional attention region for type 1 tokens - # This allows type 1 tokens to attend to each other bidirectionally - type1_mask = torch.zeros_like(causal_mask) # Shape: [B, num_heads, S, S] - type1_region = type1_tokens & type1_tokens.transpose( - -1, -2 - ) # Shape: [B, 1, S, S] - type1_mask = type1_mask.masked_fill(type1_region, 1.0).to(torch.bool) - - # Apply bidirectional attention: zero out causal constraints in type 1 regions - causal_mask = torch.where( - type1_mask, # Where type 1 tokens interact with each other - torch.zeros_like( - causal_mask - ), # Remove causal masking (allow bidirectional) - causal_mask, # Keep original causal masking for other regions - ) - - # Handle special case for SDPA with CUDA/XPU devices - 
if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu"] - and not output_attentions - ): - # Ensure attention to all tokens in fully masked rows for memory-efficient attention - # This is required for F.scaled_dot_product_attention's memory-efficient path - # when using left padding. See: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended( - causal_mask, min_dtype - ) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen2_5_VLConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. 
- config (`Qwen2_5_VLConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device, - ) - diagonal_attend_mask = torch.arange( - target_length, device=device - ) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if ( - not isinstance(past_key_values, SlidingWindowCache) - or sequence_length > target_length - ): - sliding_attend_mask = torch.arange( - target_length, device=device - ) <= (cache_position.reshape(-1, 1) - config.sliding_window) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = ( - causal_mask.clone() - ) # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ - :, None, None, : - ].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[ - :, :, :, :mask_length - ].masked_fill(padding_mask, min_dtype) - return causal_mask - - class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): """ Qwen2.5 Vision-Language Mixture of Experts model for action processing. 
@@ -2321,6 +1871,145 @@ class WallXPolicy(PreTrainedPolicy): self.reset() + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = False, + **kwargs, + ) -> T: + """ + Build a WallXPolicy and load `model.safetensors` weights into it on a best-effort basis. + + Args: + pretrained_name_or_path: Path to pretrained model or model identifier (ValueError if None) + config: Configuration to use; loaded from pretrained_name_or_path when None + force_download: Force download even if cached + resume_download: Resume interrupted download + proxies: Proxy configuration + token: Authentication token + cache_dir: Cache directory path + local_files_only: Only use local files + revision: Model revision + strict: Forwarded to load_state_dict(strict=...) + **kwargs: Forwarded to the policy constructor (and to the config loader when config is None) + + Returns: + WallXPolicy: The policy; it is returned even if loading the weights fails (a message is printed instead of raising) + """ + print( + "Loading Wall-X model for cross-embodiment robotic control.\n" + "This implementation integrates Qwen2.5-VL with flow matching for action prediction." 
+ ) + + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") + + # Use provided config if available, otherwise load from pretrained path + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + # Initialize model without loading weights + model = cls(config, **kwargs) + + # Load and remap the state dict + try: + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Resolve model.safetensors via transformers' cached_file (the only format attempted) + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + ) + original_state_dict = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e:  # bind the error so the message below can report the real cause + print(f"Could not load state dict: {e}") + print("Returning model without loading pretrained weights") + return model + + # Filter out normalizer statistics if present + filtered_state_dict = {} + for key, value in original_state_dict.items(): + if "action_preprocessor.normalizer" not in key: + filtered_state_dict[key] = value + else: + print(f"Filtered key: {key}") + + # Add "model." 
prefix for keys that don't have it + remapped_state_dict = {} + remap_count = 0 + + for key, value in filtered_state_dict.items(): + if not key.startswith("model."): + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + else: + remapped_state_dict[key] = value + + if remap_count > 0: + print(f"Remapped {remap_count} state dict keys") + + # Load the remapped state dict into the model + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) + + if missing_keys: + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... 
and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... 
and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All keys loaded successfully!") + + except Exception as e:  # best-effort: report the failure and fall through to return the model + print(f"Warning: Could not load state dict: {e}") + + return model + def reset(self): """Reset action queue.""" self._queues = { @@ -2329,25 +2018,13 @@ class WallXPolicy(PreTrainedPolicy): def get_optim_params(self): """Get parameters for optimization.""" - params = [] - - if self.model.visual.available: - # Add VLM parameters - if not self.config.train_expert_only: - params.extend(self.vlm.model.parameters()) - - # Always add action head parameters - if self.config.train_action_head: - params.extend(self.action_head.parameters()) - - return params + return self.parameters() def preprocess_inputs( self, batch: List[Dict[str, Any]], config: Dict[str, Any], dataload_config: Dict[str, Any], - norm_stats: Dict[str, Any], lerobot_config: Dict[str, Any], processor: Any, action_tokenizer: Optional[Any] = None, @@ -2473,25 +2150,12 @@ class WallXPolicy(PreTrainedPolicy): all_actions.append(data[action_key]) all_frame_indices.append(frame_index) - # ==================== BATCH NORMALIZATION ==================== - action_min_stat = norm_stats["action"].min - action_delta = norm_stats["action"].delta - state_min_stat = norm_stats["state"].min - state_delta = norm_stats["state"].delta - - def normalize(x, min_stat, delta): - delta = torch.where(delta == 0, torch.ones_like(delta), delta) - x = (x - min_stat) / delta - x = x * 2 - 1 - return torch.clamp(x, -1, 1) - - # Stack and normalize agent_pos + # Stack agent_pos agent_pos = torch.stack(all_agent_pos) if agent_pos.dim() == 2: agent_pos = agent_pos.unsqueeze(1) agent_pos_mask = (~torch.isnan(agent_pos)).float() agent_pos = agent_pos.nan_to_num(nan=0.0) - agent_pos = normalize(agent_pos, state_min_stat, state_delta) if agent_pos.shape[-1] != 20: pad_size = 20 - agent_pos.shape[-1] @@ -2504,13 +2168,12 @@ class WallXPolicy(PreTrainedPolicy): 
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size) ], dim=-1) - # Stack and normalize actions + # Stack actions action = torch.stack(all_actions) if action.dim() == 2: action = action.unsqueeze(1) dof_mask = (~torch.isnan(action)).float() action = action.nan_to_num(nan=0.0) - action = normalize(action, action_min_stat, action_delta) if action.shape[-1] != 20: pad_size = 20 - action.shape[-1] @@ -2563,87 +2226,63 @@ class WallXPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """ - Training forward pass. + Training forward pass using Qwen2_5_VLMoEForAction. Args: - batch: Dictionary containing observations and actions + batch: Raw batch; forward() passes it through preprocess_inputs() itself. + Keys then read from the result: input_ids, attention_mask, pixel_values, image_grid_thw, + proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types, + dataset_names, plus optional pixel_values_videos, video_grid_thw and labels. Returns: tuple: (loss, loss_dict) """ - # Prepare inputs - images = self.prepare_images(batch) - state = self.prepare_state(batch) - actions = self.prepare_action(batch) - - batch_size = actions.shape[0] - device = actions.device - dtype = actions.dtype - - # Create DOF mask - dof_mask = self._create_dof_mask(batch_size, device, dtype) - - # Process actions through action head (adds noise, gets embeddings) - action_embeds, flow_target = self.action_head(actions, dof_mask) - - # For now, use simplified loss computation - # In full implementation, would pass through VLM transformer - loss_dict = {} - - # Compute flow matching loss - # Note: In full wall-x, action_embeds would go through VLM transformer first - flow_loss = self.action_head.flow_loss(action_embeds, flow_target, dof_mask) - loss = flow_loss.mean() - - loss_dict["loss"] = loss.item() - loss_dict["flow_loss"] = loss.item() - - return loss, loss_dict - - def _sample_actions_flow(self, batch: dict[str, Tensor]) -> Tensor: - """ - Sample actions using flow matching / diffusion. 
- - Args: - batch: Dictionary containing observations - - Returns: - Predicted actions [batch, chunk_size, action_dim] - """ - batch_size = 1 # Typically inference is single sample - device = self.config.device - dtype = torch.float32 - - # Initialize with noise - noisy_action = torch.randn( - batch_size, - self.config.chunk_size, - sum(self.config.dof_config.values()), - device=device, - dtype=dtype + batch = self.preprocess_inputs( + batch, + self.config, + self.dataload_config, + self.lerobot_config, + self.processor, + self.action_tokenizer, + self.camera_keys ) - # Create DOF mask - dof_mask = self._create_dof_mask(batch_size, device, dtype) + # Call the underlying model's forward with mode="train" + outputs = self.model( + mode="train", + input_ids=batch.get("input_ids"), + attention_mask=batch.get("attention_mask"), + pixel_values=batch.get("pixel_values"), + image_grid_thw=batch.get("image_grid_thw"), + pixel_values_videos=batch.get("pixel_values_videos"), + video_grid_thw=batch.get("video_grid_thw"), + proprioception=batch.get("proprioception"), + agent_pos_mask=batch.get("agent_pos_mask"), + action_chunk=batch.get("action_chunk"), + dof_mask=batch.get("dof_mask"), + moe_token_types=batch.get("moe_token_types"), + dataset_names=batch.get("dataset_names"), + labels=batch.get("labels", batch.get("input_ids")), # Use input_ids as labels if not provided + ) - # ODE integration for denoising - num_steps = self.config.num_inference_timesteps - dt = 1.0 / num_steps + # Extract losses from output + loss = outputs.loss + loss_dict = { + "loss": loss.item() if loss is not None else 0.0, + } - for step_idx in range(num_steps + 1): - t = torch.tensor(step_idx * dt, device=device, dtype=dtype) - timestep = t.unsqueeze(0).repeat(batch_size) + if outputs.flow_loss is not None: + loss_dict["flow_loss"] = outputs.flow_loss.item() + if outputs.cross_entropy_loss is not None: + loss_dict["cross_entropy_loss"] = outputs.cross_entropy_loss.item() - # Single denoising step - 
action_embeds = self.action_head.step(timestep, noisy_action, dof_mask) + # Add channel losses if available + if outputs.channel_loss_dict is not None: + for key, value in outputs.channel_loss_dict.items(): + if isinstance(value, torch.Tensor): + loss_dict[f"channel_{key}"] = value.item() - # Predict flow (in full implementation, would go through VLM) - flow_pred = self.action_head.action_proj_back(action_embeds) - - # Euler integration step - noisy_action = noisy_action + dt * flow_pred - - return noisy_action + return loss, loss_dict @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: @@ -2651,14 +2290,37 @@ class WallXPolicy(PreTrainedPolicy): self.eval() self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) - if self.config.prediction_mode == "flow": - actions = self._sample_actions_flow(batch) + batch = self.preprocess_inputs( + batch, + self.config, + self.dataload_config, + self.lerobot_config, + self.processor, + self.action_tokenizer, + self.camera_keys + ) + + if self.config.prediction_mode == "diffusion": + actions = self.model( + **batch, + action_dim=self.config.max_action_dim, + pred_horizon=self.config.chunk_size, + mode="predict", + predict_mode="diffusion" + ) + elif self.config.prediction_mode == "fast": + actions = self.model( + **batch, + action_dim=self.config.action_feature.shape[0], + pred_horizon=self.config.chunk_size, + mode="predict", + predict_mode="fast" + ) else: raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented") # Unpad actions - original_action_dim = self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] + actions = actions[:, :, :self.config.action_feature.shape[0]] return actions diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py index 9e8352ee6..438ac044c 100644 --- a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py +++ 
b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -13,6 +13,7 @@ from transformers.cache_utils import ( SlidingWindowCache, StaticCache, ) +from transformers import AutoConfig from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput @@ -2540,9 +2541,464 @@ class Qwen2_5_VLDecoderLayer_with_MoE(nn.Module): outputs += (present_key_value,) return outputs +class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): + """Qwen2.5-VL model with Mixture of Experts (MoE) architecture. + + This model extends the base Qwen2.5-VL model by incorporating MoE layers + for improved scalability and specialization across different token types. + """ + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + num_experts: Optional[int] = None, + *args, + **kwargs, + ): + """Load a pretrained model with optional MoE configuration. + + Args: + pretrained_model_name_or_path: Path or name of the pretrained model + num_experts: Number of experts for MoE layers (if not in config) + *args: Additional arguments passed to parent class + **kwargs: Additional keyword arguments passed to parent class + + Returns: + Initialized model instance with MoE configuration + """ + config = kwargs.get("config", None) + if config is None: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + + # Override number of experts if specified + if num_experts is not None: + config.num_experts = num_experts + + kwargs["config"] = config + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + def __init__(self, config: Qwen2_5_VLConfig): + """Initialize the Qwen2.5-VL MoE model. 
+
+        Args:
+            config: Model configuration containing architecture parameters
+        """
+        super().__init__(config)
+
+        # Basic model parameters
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        # Model components
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
+
+        # Decoder layers with MoE support
+        self.layers = nn.ModuleList(
+            [
+                Qwen2_5_VLDecoderLayer_with_MoE(config, layer_idx, config.num_experts)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+
+        # Model configuration
+        self._attn_implementation = config._attn_implementation
+        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        """Get the input embedding layer.
+
+        Returns:
+            The token embedding layer
+        """
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value: nn.Embedding) -> None:
+        """Set the input embedding layer.
+
+        Args:
+            value: New embedding layer to use
+        """
+        self.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        moe_token_types: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """Run the decoder stack over token ids or precomputed embeddings.
+
+        Args:
+            input_ids: Token ids, embedded with `self.embed_tokens`. Mutually
+                exclusive with `inputs_embeds`.
+            attention_mask: Mask forwarded to `_update_causal_mask`.
+            position_ids: Position ids. A 2D tensor gets a new leading axis
+                expanded to size 3; when omitted they are derived from
+                `cache_position`.
+            past_key_values: Existing cache. When `use_cache` is set and this is
+                None, a `DynamicCache` is created (except under `torch.jit` tracing).
+            inputs_embeds: Precomputed embeddings, used instead of `input_ids`.
+            moe_token_types: Token-type ids, required. Passed to every decoder
+                layer as `token_types` and to the mask builder (presumably shaped
+                `(batch, seq)` - TODO confirm against the caller).
+            use_cache: Whether to return the cache; falls back to
+                `config.use_cache`. Forced to False when gradient checkpointing is
+                active in training mode.
+            output_attentions: Whether to collect per-layer attention weights;
+                falls back to `config.output_attentions`.
+            output_hidden_states: Whether to collect per-layer hidden states;
+                falls back to `config.output_hidden_states`.
+            return_dict: Whether to return a `BaseModelOutputWithPast` instead of
+                a tuple; falls back to `config.use_return_dict`.
+            cache_position: Positions of the current tokens; defaults to a range
+                starting at the number of tokens already held by the cache.
+            **kwargs: Accepted but not read by this method.
+
+        Returns:
+            A `BaseModelOutputWithPast`, or, when `return_dict` is false, a tuple
+            of the non-None values among (last hidden state, cache, all hidden
+            states, all attention weights).
+
+        Raises:
+            ValueError: If both or neither of `input_ids` and `inputs_embeds` are
+                given, or if `moe_token_types` is None.
+        """
+        # Set default output options
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # Validate inputs
+        # The XOR is true when both inputs are supplied or both are missing
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
+
+        if moe_token_types is None:
+            raise ValueError("moe_token_types must be provided for MoE routing")
+
+        # Handle gradient checkpointing compatibility
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # Initialize cache if needed
+        if use_cache and past_key_values is None and not torch.jit.is_tracing():
+            past_key_values = DynamicCache()
+
+        # Get input embeddings
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        # Set up cache position
+        if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
+        # Set up position IDs (hardcoded 3 dimensions for temporal, height, width)
+        if position_ids is None:
+            position_ids = cache_position.view(1, 1, -1).expand(
+                3, inputs_embeds.shape[0], -1
+            )
+        elif position_ids.dim() == 2:
+            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+        # Create causal attention mask
+        causal_mask = self._update_causal_mask(
+            attention_mask,
+            inputs_embeds,
+            cache_position,
+            past_key_values,
+            output_attentions,
+            moe_token_types,
+        )
+
+        hidden_states = inputs_embeds
+
+        # Create position embeddings to be shared across decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # Initialize output collections
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        # Process through decoder layers
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                # Use gradient checkpointing during training
+                # NOTE(review): the arguments are positional here, so their order
+                # must match the decoder layer's forward signature - confirm
+                # against Qwen2_5_VLDecoderLayer_with_MoE.
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    moe_token_types,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
+                )
+            else:
+                # Regular forward pass
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    token_types=moe_token_types,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            # Update cache if using it
+            # The cache is read from index 2 when attention weights occupy index 1
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            # Collect attention weights if requested
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        # Apply final layer normalization
+        hidden_states = self.norm(hidden_states)
+
+        # Add final hidden states if collecting all states
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        # Return outputs in requested format
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool,
+        moe_token_types: Optional[torch.LongTensor] = None,
+    ):
+        """Update causal attention mask with support for bidirectional attention for specific token types.
+
+        This method creates and modifies attention masks to support different attention patterns:
+        - Standard causal (unidirectional) attention for most tokens
+        - Bidirectional attention among tokens whose `moe_token_types` value is 1
+
+        Args:
+            attention_mask: Input attention mask to avoid attending to padding tokens
+            input_tensor: Input embeddings tensor for shape and device information
+            cache_position: Position indices for caching mechanisms
+            past_key_values: Cached key-value pairs from previous forward passes
+            output_attentions: Whether attention weights will be returned
+            moe_token_types: Optional tensor indicating token types for MoE routing
+                (type 1 tokens will use bidirectional attention)
+
+        Returns:
+            Updated causal attention mask, or None when Flash Attention 2 is
+            configured or when the SDPA shortcut decides the explicit mask can
+            be dropped. In both None cases the type 1 adjustment is not applied.
+        """
+        # Flash Attention 2 handles masking internally
+        # NOTE(review): returning here means the type 1 bidirectional region built
+        # further down is never expressed through this mask under
+        # flash_attention_2 - confirm it is handled elsewhere.
+        if self.config._attn_implementation == "flash_attention_2":
+            return None
+
+        # Calculate sequence lengths for cache management
+        past_seen_tokens = (
+            past_key_values.get_seq_length() if past_key_values is not None else 0
+        )
+        using_static_cache = isinstance(past_key_values, StaticCache)
+        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+        # For SDPA (Scaled Dot Product Attention), use `is_causal` argument when possible
+        # instead of explicit attention mask to enable Flash Attention 2 dispatch
+        # Note: This optimization is not compatible with static cache
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
+            # Check if we can ignore the causal mask and rely on SDPA's internal handling
+            # NOTE(review): this early return also skips the type 1 adjustment below.
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                sliding_window=self.config.sliding_window,
+                is_training=self.training,
+            ):
+                return None
+
+        # Extract tensor properties for mask creation
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+
+        # Determine target length based on cache type
+        if using_sliding_window_cache or using_static_cache:
+            # Use maximum cache shape for sliding window or static caches
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            # For dynamic cache or no cache, calculate based on attention mask or sequence length
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # Generate 4D causal attention mask from 2D input mask if provided
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
+        )
+
+        # Modify mask to support bidirectional attention for specific token types
+        if moe_token_types is not None:
+            # Identify positions of type 1 tokens
+            type1_tokens = (
+                (moe_token_types == 1).unsqueeze(1).unsqueeze(2)
+            )  # Shape: [B, 1, 1, S], assuming moe_token_types is [B, S] - TODO confirm
+
+            # Create bidirectional attention region for type 1 tokens
+            # This allows type 1 tokens to attend to each other bidirectionally
+            # NOTE(review): type1_region is square over the token-type length, so it
+            # only broadcasts against causal_mask when the mask's query and key
+            # lengths both match it - confirm behaviour with a non-empty cache.
+            type1_mask = torch.zeros_like(causal_mask)  # Same shape as causal_mask
+            type1_region = type1_tokens & type1_tokens.transpose(
+                -1, -2
+            )  # Shape: [B, 1, S, S]
+            type1_mask = type1_mask.masked_fill(type1_region, 1.0).to(torch.bool)
+
+            # Apply bidirectional attention: zero out causal constraints in type 1 regions
+            causal_mask = torch.where(
+                type1_mask,  # Where type 1 tokens interact with each other
+                torch.zeros_like(
+                    causal_mask
+                ),  # Remove causal masking (allow bidirectional)
+                causal_mask,  # Keep original causal masking for other regions
+            )
+
+        # Handle special case for SDPA with CUDA/XPU devices
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu"]
+            and not output_attentions
+        ):
+            # Ensure attention to all tokens in fully masked rows for memory-efficient attention
+            # This is required for F.scaled_dot_product_attention's memory-efficient path
+            # when using left padding. See: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, min_dtype
+            )
+
+        return causal_mask
+
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        config: Qwen2_5_VLConfig,
+        past_key_values: Cache,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+            config (`Qwen2_5_VLConfig`):
+                The model's configuration class
+            past_key_values (`Cache`):
+                The cache class that is being used currently to generate
+
+        Returns:
+            The 4D mask. A 4D `attention_mask` is returned unchanged; otherwise the
+            result holds `torch.finfo(dtype).min` at masked positions and 0 elsewhere.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length),
+                fill_value=min_dtype,
+                dtype=dtype,
+                device=device,
+            )
+            # True where the key index lies strictly after the query's cache position
+            diagonal_attend_mask = torch.arange(
+                target_length, device=device
+            ) > cache_position.reshape(-1, 1)
+            if config.sliding_window is not None:
+                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+                # the check is needed to verify if the current checkpoint was trained with sliding window or not
+                if (
+                    not isinstance(past_key_values, SlidingWindowCache)
+                    or sequence_length > target_length
+                ):
+                    sliding_attend_mask = torch.arange(
+                        target_length, device=device
+                    ) <= (cache_position.reshape(-1, 1) - config.sliding_window)
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            # Multiplying by the boolean mask keeps min_dtype where it is True and zeroes the rest
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = (
+                    causal_mask.clone()
+                )  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                # The sum is 0 only where the causal mask is 0 (attend) and the 2D mask is 0
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
+                    :, None, None, :
+                ].to(causal_mask.device)
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[
+                    :, :, :, :mask_length
+                ].masked_fill(padding_mask, min_dtype)
+        return causal_mask
+
 __all__ = [
     "Qwen2_5_VLForConditionalGeneration",
     "Qwen2_5_VLModel",
     "Qwen2_5_VLPreTrainedModel",
     "Qwen2_5_VLDecoderLayer_with_MoE",
+    "Qwen2_5_VLMoEModel",
 ]
\ No newline at end of file