update the policy methods

This commit is contained in:
Geoffrey19
2025-12-02 12:11:29 +08:00
committed by Michel Aractingi
parent a8e7a2967c
commit feebca050a
2 changed files with 676 additions and 558 deletions
+220 -558
View File
@@ -34,6 +34,7 @@ lerobot-train \
```
"""
import builtins
import glob
import math
import os
@@ -70,7 +71,8 @@ from transformers.utils import is_torchdynamo_compiling, logging
from transformers import AutoProcessor, BatchFeature
from qwen_vl_utils.vision_process import smart_resize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
@@ -91,8 +93,11 @@ from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLDecoderLayer_with_MoE,
Qwen2_5_VLACausalLMOutputWithPast,
Qwen2_5_VLMoEModel,
)
logger = logging.get_logger(__name__)
# Add wall-x repo to path if available
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
if WALL_X_PATH.exists():
@@ -257,461 +262,6 @@ class ActionHead(nn.Module):
return self.propri_proj(proprioception)
class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel):
"""Qwen2.5-VL model with Mixture of Experts (MoE) architecture.
This model extends the base Qwen2.5-VL model by incorporating MoE layers
for improved scalability and specialization across different token types.
"""
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
num_experts: Optional[int] = None,
*args,
**kwargs,
):
"""Load a pretrained model with optional MoE configuration.
Args:
pretrained_model_name_or_path: Path or name of the pretrained model
num_experts: Number of experts for MoE layers (if not in config)
*args: Additional arguments passed to parent class
**kwargs: Additional keyword arguments passed to parent class
Returns:
Initialized model instance with MoE configuration
"""
config = kwargs.get("config", None)
if config is None:
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
# Override number of experts if specified
if num_experts is not None:
config.num_experts = num_experts
kwargs["config"] = config
return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
def __init__(self, config: Qwen2_5_VLConfig):
"""Initialize the Qwen2.5-VL MoE model.
Args:
config: Model configuration containing architecture parameters
"""
super().__init__(config)
# Basic model parameters
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# Model components
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
# Decoder layers with MoE support
self.layers = nn.ModuleList(
[
Qwen2_5_VLDecoderLayer_with_MoE(config, layer_idx, config.num_experts)
for layer_idx in range(config.num_hidden_layers)
]
)
# Model configuration
self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
"""Get the input embedding layer.
Returns:
The token embedding layer
"""
return self.embed_tokens
def set_input_embeddings(self, value: nn.Embedding) -> None:
"""Set the input embedding layer.
Args:
value: New embedding layer to use
"""
self.embed_tokens = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
moe_token_types: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
# Set default output options
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Validate inputs
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if moe_token_types is None:
raise ValueError("moe_token_types must be provided for MoE routing")
# Handle gradient checkpointing compatibility
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Initialize cache if needed
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()
# Get input embeddings
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Set up cache position
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
# Set up position IDs (hardcoded 3 dimensions for temporal, height, width)
if position_ids is None:
position_ids = cache_position.view(1, 1, -1).expand(
3, inputs_embeds.shape[0], -1
)
elif position_ids.dim() == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
# Create causal attention mask
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
moe_token_types,
)
hidden_states = inputs_embeds
# Create position embeddings to be shared across decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# Initialize output collections
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
# Process through decoder layers
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
# Use gradient checkpointing during training
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
moe_token_types,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
# Regular forward pass
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
token_types=moe_token_types,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
# Update cache if using it
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
# Collect attention weights if requested
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Apply final layer normalization
hidden_states = self.norm(hidden_states)
# Add final hidden states if collecting all states
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
# Return outputs in requested format
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
moe_token_types: Optional[torch.LongTensor] = None,
):
"""Update causal attention mask with support for bidirectional attention for specific token types.
This method creates and modifies attention masks to support different attention patterns:
- Standard causal (unidirectional) attention for most tokens
- Bidirectional attention for specific token types (e.g., MoE routing tokens)
Args:
attention_mask: Input attention mask to avoid attending to padding tokens
input_tensor: Input embeddings tensor for shape and device information
cache_position: Position indices for caching mechanisms
past_key_values: Cached key-value pairs from previous forward passes
output_attentions: Whether attention weights will be returned
moe_token_types: Optional tensor indicating token types for MoE routing
(type 1 tokens will use bidirectional attention)
Returns:
Updated causal attention mask, or None if using Flash Attention 2
"""
# Flash Attention 2 handles masking internally
if self.config._attn_implementation == "flash_attention_2":
return None
# Calculate sequence lengths for cache management
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# For SDPA (Scaled Dot Product Attention), use `is_causal` argument when possible
# instead of explicit attention mask to enable Flash Attention 2 dispatch
# Note: This optimization is not compatible with static cache
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
# Check if we can ignore the causal mask and rely on SDPA's internal handling
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
# Extract tensor properties for mask creation
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# Determine target length based on cache type
if using_sliding_window_cache or using_static_cache:
# Use maximum cache shape for sliding window or static caches
target_length = past_key_values.get_max_cache_shape()
else:
# For dynamic cache or no cache, calculate based on attention mask or sequence length
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# Generate 4D causal attention mask from 2D input mask if provided
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
# Modify mask to support bidirectional attention for specific token types
if moe_token_types is not None:
# Identify positions of type 1 tokens (MoE routing tokens)
type1_tokens = (
(moe_token_types == 1).unsqueeze(1).unsqueeze(2)
) # Shape: [B, 1, 1, S]
# Create bidirectional attention region for type 1 tokens
# This allows type 1 tokens to attend to each other bidirectionally
type1_mask = torch.zeros_like(causal_mask) # Shape: [B, num_heads, S, S]
type1_region = type1_tokens & type1_tokens.transpose(
-1, -2
) # Shape: [B, 1, S, S]
type1_mask = type1_mask.masked_fill(type1_region, 1.0).to(torch.bool)
# Apply bidirectional attention: zero out causal constraints in type 1 regions
causal_mask = torch.where(
type1_mask, # Where type 1 tokens interact with each other
torch.zeros_like(
causal_mask
), # Remove causal masking (allow bidirectional)
causal_mask, # Keep original causal masking for other regions
)
# Handle special case for SDPA with CUDA/XPU devices
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu"]
and not output_attentions
):
# Ensure attention to all tokens in fully masked rows for memory-efficient attention
# This is required for F.scaled_dot_product_attention's memory-efficient path
# when using left padding. See: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen2_5_VLConfig,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`Qwen2_5_VLConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
diagonal_attend_mask = torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
if config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if (
not isinstance(past_key_values, SlidingWindowCache)
or sequence_length > target_length
):
sliding_attend_mask = torch.arange(
target_length, device=device
) <= (cache_position.reshape(-1, 1) - config.sliding_window)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
:, None, None, :
].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
"""
Qwen2.5 Vision-Language Mixture of Experts model for action processing.
@@ -2321,6 +1871,145 @@ class WallXPolicy(PreTrainedPolicy):
self.reset()
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = False,
**kwargs,
) -> T:
"""
Load WallXPolicy from a pretrained model path.
Args:
pretrained_name_or_path: Path to pretrained model or model identifier
config: Optional configuration object
force_download: Force download even if cached
resume_download: Resume interrupted download
proxies: Proxy configuration
token: Authentication token
cache_dir: Cache directory path
local_files_only: Only use local files
revision: Model revision
strict: Strict loading of state dict
**kwargs: Additional arguments
Returns:
WallXPolicy: Loaded policy instance
"""
print(
"Loading Wall-X model for cross-embodiment robotic control.\n"
"This implementation integrates Qwen2.5-VL with flow matching for action prediction."
)
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
# Use provided config if available, otherwise load from pretrained path
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
# Initialize model without loading weights
model = cls(config, **kwargs)
# Load and remap the state dict
try:
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
)
original_state_dict = load_file(resolved_file)
print("✓ Loaded state dict from model.safetensors")
except Exception:
print(f"Could not load state dict: {e}")
print("Returning model without loading pretrained weights")
return model
# Filter out normalizer statistics if present
filtered_state_dict = {}
for key, value in original_state_dict.items():
if "action_preprocessor.normalizer" not in key:
filtered_state_dict[key] = value
else:
print(f"Filtered key: {key}")
# Add "model." prefix for keys that don't have it
remapped_state_dict = {}
remap_count = 0
for key, value in filtered_state_dict.items():
if not key.startswith("model."):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
else:
remapped_state_dict[key] = value
if remap_count > 0:
print(f"Remapped {remap_count} state dict keys")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
if missing_keys:
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not load state dict: {e}")
return model
def reset(self):
"""Reset action queue."""
self._queues = {
@@ -2329,25 +2018,13 @@ class WallXPolicy(PreTrainedPolicy):
def get_optim_params(self):
"""Get parameters for optimization."""
params = []
if self.model.visual.available:
# Add VLM parameters
if not self.config.train_expert_only:
params.extend(self.vlm.model.parameters())
# Always add action head parameters
if self.config.train_action_head:
params.extend(self.action_head.parameters())
return params
return self.parameters()
def preprocess_inputs(
self,
batch: List[Dict[str, Any]],
config: Dict[str, Any],
dataload_config: Dict[str, Any],
norm_stats: Dict[str, Any],
lerobot_config: Dict[str, Any],
processor: Any,
action_tokenizer: Optional[Any] = None,
@@ -2473,25 +2150,12 @@ class WallXPolicy(PreTrainedPolicy):
all_actions.append(data[action_key])
all_frame_indices.append(frame_index)
# ==================== BATCH NORMALIZATION ====================
action_min_stat = norm_stats["action"].min
action_delta = norm_stats["action"].delta
state_min_stat = norm_stats["state"].min
state_delta = norm_stats["state"].delta
def normalize(x, min_stat, delta):
delta = torch.where(delta == 0, torch.ones_like(delta), delta)
x = (x - min_stat) / delta
x = x * 2 - 1
return torch.clamp(x, -1, 1)
# Stack and normalize agent_pos
# Stack agent_pos
agent_pos = torch.stack(all_agent_pos)
if agent_pos.dim() == 2:
agent_pos = agent_pos.unsqueeze(1)
agent_pos_mask = (~torch.isnan(agent_pos)).float()
agent_pos = agent_pos.nan_to_num(nan=0.0)
agent_pos = normalize(agent_pos, state_min_stat, state_delta)
if agent_pos.shape[-1] != 20:
pad_size = 20 - agent_pos.shape[-1]
@@ -2504,13 +2168,12 @@ class WallXPolicy(PreTrainedPolicy):
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size)
], dim=-1)
# Stack and normalize actions
# Stack actions
action = torch.stack(all_actions)
if action.dim() == 2:
action = action.unsqueeze(1)
dof_mask = (~torch.isnan(action)).float()
action = action.nan_to_num(nan=0.0)
action = normalize(action, action_min_stat, action_delta)
if action.shape[-1] != 20:
pad_size = 20 - action.shape[-1]
@@ -2563,87 +2226,63 @@ class WallXPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""
Training forward pass.
Training forward pass using Qwen2_5_VLMoEForAction.
Args:
batch: Dictionary containing observations and actions
batch: Dictionary containing preprocessed inputs from preprocess_inputs()
Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw,
proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types,
dataset_names, etc.
Returns:
tuple: (loss, loss_dict)
"""
# Prepare inputs
images = self.prepare_images(batch)
state = self.prepare_state(batch)
actions = self.prepare_action(batch)
batch_size = actions.shape[0]
device = actions.device
dtype = actions.dtype
# Create DOF mask
dof_mask = self._create_dof_mask(batch_size, device, dtype)
# Process actions through action head (adds noise, gets embeddings)
action_embeds, flow_target = self.action_head(actions, dof_mask)
# For now, use simplified loss computation
# In full implementation, would pass through VLM transformer
loss_dict = {}
# Compute flow matching loss
# Note: In full wall-x, action_embeds would go through VLM transformer first
flow_loss = self.action_head.flow_loss(action_embeds, flow_target, dof_mask)
loss = flow_loss.mean()
loss_dict["loss"] = loss.item()
loss_dict["flow_loss"] = loss.item()
return loss, loss_dict
def _sample_actions_flow(self, batch: dict[str, Tensor]) -> Tensor:
"""
Sample actions using flow matching / diffusion.
Args:
batch: Dictionary containing observations
Returns:
Predicted actions [batch, chunk_size, action_dim]
"""
batch_size = 1 # Typically inference is single sample
device = self.config.device
dtype = torch.float32
# Initialize with noise
noisy_action = torch.randn(
batch_size,
self.config.chunk_size,
sum(self.config.dof_config.values()),
device=device,
dtype=dtype
batch = self.preprocess_inputs(
batch,
self.config,
self.dataload_config,
self.lerobot_config,
self.processor,
self.action_tokenizer,
self.camera_keys
)
# Create DOF mask
dof_mask = self._create_dof_mask(batch_size, device, dtype)
# Call the underlying model's forward with mode="train"
outputs = self.model(
mode="train",
input_ids=batch.get("input_ids"),
attention_mask=batch.get("attention_mask"),
pixel_values=batch.get("pixel_values"),
image_grid_thw=batch.get("image_grid_thw"),
pixel_values_videos=batch.get("pixel_values_videos"),
video_grid_thw=batch.get("video_grid_thw"),
proprioception=batch.get("proprioception"),
agent_pos_mask=batch.get("agent_pos_mask"),
action_chunk=batch.get("action_chunk"),
dof_mask=batch.get("dof_mask"),
moe_token_types=batch.get("moe_token_types"),
dataset_names=batch.get("dataset_names"),
labels=batch.get("labels", batch.get("input_ids")), # Use input_ids as labels if not provided
)
# ODE integration for denoising
num_steps = self.config.num_inference_timesteps
dt = 1.0 / num_steps
# Extract losses from output
loss = outputs.loss
loss_dict = {
"loss": loss.item() if loss is not None else 0.0,
}
for step_idx in range(num_steps + 1):
t = torch.tensor(step_idx * dt, device=device, dtype=dtype)
timestep = t.unsqueeze(0).repeat(batch_size)
if outputs.flow_loss is not None:
loss_dict["flow_loss"] = outputs.flow_loss.item()
if outputs.cross_entropy_loss is not None:
loss_dict["cross_entropy_loss"] = outputs.cross_entropy_loss.item()
# Single denoising step
action_embeds = self.action_head.step(timestep, noisy_action, dof_mask)
# Add channel losses if available
if outputs.channel_loss_dict is not None:
for key, value in outputs.channel_loss_dict.items():
if isinstance(value, torch.Tensor):
loss_dict[f"channel_{key}"] = value.item()
# Predict flow (in full implementation, would go through VLM)
flow_pred = self.action_head.action_proj_back(action_embeds)
# Euler integration step
noisy_action = noisy_action + dt * flow_pred
return noisy_action
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
@@ -2651,14 +2290,37 @@ class WallXPolicy(PreTrainedPolicy):
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if self.config.prediction_mode == "flow":
actions = self._sample_actions_flow(batch)
batch = self.preprocess_inputs(
batch,
self.config,
self.dataload_config,
self.lerobot_config,
self.processor,
self.action_tokenizer,
self.camera_keys
)
if self.config.prediction_mode == "diffusion":
actions = self.model(
**batch,
action_dim=self.config.max_action_dim,
pred_horizon=self.config.chunk_size,
mode="predict",
predict_mode="diffusion"
)
elif self.config.prediction_mode == "fast":
actions = self.model(
**batch,
action_dim=self.config.action_feature.shape[0],
pred_horizon=self.config.chunk_size,
mode="predict",
predict_mode="fast"
)
else:
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = actions[:, :, :self.config.action_feature.shape[0]]
return actions
@@ -13,6 +13,7 @@ from transformers.cache_utils import (
SlidingWindowCache,
StaticCache,
)
from transformers import AutoConfig
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
@@ -2540,9 +2541,464 @@ class Qwen2_5_VLDecoderLayer_with_MoE(nn.Module):
outputs += (present_key_value,)
return outputs
class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel):
"""Qwen2.5-VL model with Mixture of Experts (MoE) architecture.
This model extends the base Qwen2.5-VL model by incorporating MoE layers
for improved scalability and specialization across different token types.
"""
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
num_experts: Optional[int] = None,
*args,
**kwargs,
):
"""Load a pretrained model with optional MoE configuration.
Args:
pretrained_model_name_or_path: Path or name of the pretrained model
num_experts: Number of experts for MoE layers (if not in config)
*args: Additional arguments passed to parent class
**kwargs: Additional keyword arguments passed to parent class
Returns:
Initialized model instance with MoE configuration
"""
config = kwargs.get("config", None)
if config is None:
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
# Override number of experts if specified
if num_experts is not None:
config.num_experts = num_experts
kwargs["config"] = config
return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
def __init__(self, config: Qwen2_5_VLConfig):
"""Initialize the Qwen2.5-VL MoE model.
Args:
config: Model configuration containing architecture parameters
"""
super().__init__(config)
# Basic model parameters
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# Model components
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
# Decoder layers with MoE support
self.layers = nn.ModuleList(
[
Qwen2_5_VLDecoderLayer_with_MoE(config, layer_idx, config.num_experts)
for layer_idx in range(config.num_hidden_layers)
]
)
# Model configuration
self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
"""Get the input embedding layer.
Returns:
The token embedding layer
"""
return self.embed_tokens
def set_input_embeddings(self, value: nn.Embedding) -> None:
"""Set the input embedding layer.
Args:
value: New embedding layer to use
"""
self.embed_tokens = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
moe_token_types: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
# Set default output options
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Validate inputs
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if moe_token_types is None:
raise ValueError("moe_token_types must be provided for MoE routing")
# Handle gradient checkpointing compatibility
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Initialize cache if needed
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()
# Get input embeddings
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Set up cache position
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
# Set up position IDs (hardcoded 3 dimensions for temporal, height, width)
if position_ids is None:
position_ids = cache_position.view(1, 1, -1).expand(
3, inputs_embeds.shape[0], -1
)
elif position_ids.dim() == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
# Create causal attention mask
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
moe_token_types,
)
hidden_states = inputs_embeds
# Create position embeddings to be shared across decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# Initialize output collections
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
# Process through decoder layers
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
# Use gradient checkpointing during training
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
moe_token_types,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
# Regular forward pass
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
token_types=moe_token_types,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
# Update cache if using it
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
# Collect attention weights if requested
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Apply final layer normalization
hidden_states = self.norm(hidden_states)
# Add final hidden states if collecting all states
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
# Return outputs in requested format
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
moe_token_types: Optional[torch.LongTensor] = None,
):
"""Update causal attention mask with support for bidirectional attention for specific token types.
This method creates and modifies attention masks to support different attention patterns:
- Standard causal (unidirectional) attention for most tokens
- Bidirectional attention for specific token types (e.g., MoE routing tokens)
Args:
attention_mask: Input attention mask to avoid attending to padding tokens
input_tensor: Input embeddings tensor for shape and device information
cache_position: Position indices for caching mechanisms
past_key_values: Cached key-value pairs from previous forward passes
output_attentions: Whether attention weights will be returned
moe_token_types: Optional tensor indicating token types for MoE routing
(type 1 tokens will use bidirectional attention)
Returns:
Updated causal attention mask, or None if using Flash Attention 2
"""
# Flash Attention 2 handles masking internally
if self.config._attn_implementation == "flash_attention_2":
return None
# Calculate sequence lengths for cache management
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# For SDPA (Scaled Dot Product Attention), use `is_causal` argument when possible
# instead of explicit attention mask to enable Flash Attention 2 dispatch
# Note: This optimization is not compatible with static cache
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
# Check if we can ignore the causal mask and rely on SDPA's internal handling
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
# Extract tensor properties for mask creation
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# Determine target length based on cache type
if using_sliding_window_cache or using_static_cache:
# Use maximum cache shape for sliding window or static caches
target_length = past_key_values.get_max_cache_shape()
else:
# For dynamic cache or no cache, calculate based on attention mask or sequence length
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# Generate 4D causal attention mask from 2D input mask if provided
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
# Modify mask to support bidirectional attention for specific token types
if moe_token_types is not None:
# Identify positions of type 1 tokens (MoE routing tokens)
type1_tokens = (
(moe_token_types == 1).unsqueeze(1).unsqueeze(2)
) # Shape: [B, 1, 1, S]
# Create bidirectional attention region for type 1 tokens
# This allows type 1 tokens to attend to each other bidirectionally
type1_mask = torch.zeros_like(causal_mask) # Shape: [B, num_heads, S, S]
type1_region = type1_tokens & type1_tokens.transpose(
-1, -2
) # Shape: [B, 1, S, S]
type1_mask = type1_mask.masked_fill(type1_region, 1.0).to(torch.bool)
# Apply bidirectional attention: zero out causal constraints in type 1 regions
causal_mask = torch.where(
type1_mask, # Where type 1 tokens interact with each other
torch.zeros_like(
causal_mask
), # Remove causal masking (allow bidirectional)
causal_mask, # Keep original causal masking for other regions
)
# Handle special case for SDPA with CUDA/XPU devices
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu"]
and not output_attentions
):
# Ensure attention to all tokens in fully masked rows for memory-efficient attention
# This is required for F.scaled_dot_product_attention's memory-efficient path
# when using left padding. See: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen2_5_VLConfig,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`Qwen2_5_VLConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
diagonal_attend_mask = torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
if config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if (
not isinstance(past_key_values, SlidingWindowCache)
or sequence_length > target_length
):
sliding_attend_mask = torch.arange(
target_length, device=device
) <= (cache_position.reshape(-1, 1) - config.sliding_window)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
:, None, None, :
].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
__all__ = [
"Qwen2_5_VLForConditionalGeneration",
"Qwen2_5_VLModel",
"Qwen2_5_VLPreTrainedModel",
"Qwen2_5_VLDecoderLayer_with_MoE",
"Qwen2_5_VLMoEModel",
]