remove flash-attn requirement && fix bug in inference and fast mode

This commit is contained in:
Geoffrey19
2025-12-10 22:40:05 +08:00
committed by Michel Aractingi
parent d2b16afb12
commit fc6262e23d
8 changed files with 63 additions and 44 deletions
+1 -2
View File
@@ -131,8 +131,7 @@ wallx = [
"peft==0.17.1",
"scipy==1.15.3",
"torchdiffeq==0.2.5",
"qwen_vl_utils==0.0.11",
"flash-attn==2.7.4.post1"
"qwen_vl_utils==0.0.11"
]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
@@ -53,13 +53,15 @@ class WallXConfig(PreTrainedConfig):
# Pretrained model paths
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
# Tokenizer settings
action_tokenizer_path: str | None = "physical-intelligence/fast"
# Action prediction mode: "diffusion" or "fast"
prediction_mode: str = "diffusion"
# Tokenizer settings
use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
# Attention Implementation, options: "eager", "flash_attention_2", "sdpa"
# NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation
attn_implementation: str = "eager"
# ==================== Optimizer Presets ====================
optimizer_lr: float = 2e-5
@@ -87,11 +89,16 @@ class WallXConfig(PreTrainedConfig):
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
# Sync prediction_mode with use_fast_tokenizer
if self.use_fast_tokenizer:
self.prediction_mode = "fast"
# Assign use_fast_tokenizer based on prediction_mode
if self.prediction_mode == "fast":
self.use_fast_tokenizer = True
elif self.prediction_mode == "diffusion":
self.use_fast_tokenizer = False
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
else:
self.prediction_mode = "diffusion"
raise ValueError(
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
def validate_features(self) -> None:
"""Validate and set up input/output features."""
-2
View File
@@ -18,8 +18,6 @@
Wall-X Constants and Configuration Data.
"""
from lerobot.utils.constants import OBS_STATE, OBS_IMAGES, ACTION
CAMERA_NAME_MAPPING = {
"face_view": "front view",
"left_wrist_view": "left wrist view",
+44 -22
View File
@@ -63,8 +63,22 @@ from lerobot.policies.utils import populate_queues
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.policies.wall_x.utils import *
from lerobot.policies.wall_x.constant import *
from lerobot.policies.wall_x.utils import (
replace_action_token,
preprocesser_call,
get_wallx_normal_text,
process_grounding_points,
)
from lerobot.policies.wall_x.constant import (
MODEL_TYPE,
TOKENIZER_MAX_LENGTH,
PRIORITY_ORDER,
GENERATE_SUBTASK_RATIO,
RESOLUTION,
MAX_PIXELS,
MIN_PIXELS,
IMAGE_FACTOR,
)
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
@@ -261,6 +275,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
pretrained_name_or_path,
config=None,
action_tokenizer_path=None,
attn_implementation: str = 'eager',
cache_dir: str | PathLike | None = None,
force_download: bool = False,
local_files_only: bool = False,
@@ -276,6 +291,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
pretrained_model_path (str): Model directory path containing model.safetensors file
config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
attn_implementation (str, optional): Attention implementation, if None will load from default config
**kwargs: Additional arguments
Returns:
@@ -292,14 +308,18 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
strict=strict,
**kwargs,
)
if attn_implementation is not None:
config._attn_implementation = attn_implementation
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
if action_tokenizer_path is not None:
processor.action_processor = AutoProcessor.from_pretrained(
action_tokenizer = AutoProcessor.from_pretrained(
action_tokenizer_path, trust_remote_code=True
)
processor.action_processor = action_tokenizer
else:
action_tokenizer = None
# Initialize model with configuration and processor
model = cls(config, processor=processor, **kwargs)
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
# Resize token embeddings to match processor tokenizer vocabulary size
model.resize_token_embeddings(len(processor.tokenizer))
@@ -379,6 +399,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
self.flow_loss_weight = flow_loss_weight
self.use_fast_tokenizer = use_fast_tokenizer
self.processor = processor
self.action_tokenizer = action_tokenizer
# Define action token IDs
self.define_action_token_id()
@@ -1279,7 +1300,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
if labels is not None:
labels = labels[:, split_pos + 3 :]
else:
raise Warning(
raise ValueError(
"input_ids does not contain the generation prompt tokens <|im_start|>assistant"
)
@@ -1826,7 +1847,7 @@ class WallXPolicy(PreTrainedPolicy):
self.config = config
# Initialize the wall-x model
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path, attn_implementation=config.attn_implementation)
self.model.to(config.device)
self.model.to_bfloat16_for_selected_params()
@@ -1950,22 +1971,23 @@ class WallXPolicy(PreTrainedPolicy):
], dim=-1)
# ==================== PROCESS ACTIONS ====================
action = batch[ACTION] # (batch_size, chunk_size, action_dim)
if action.dim() == 2:
action = action.unsqueeze(1)
dof_mask = (~torch.isnan(action)).float()
action = action.nan_to_num(nan=0.0)
action = batch.get(ACTION, None) # (batch_size, chunk_size, action_dim)
if action is not None:
if action.dim() == 2:
action = action.unsqueeze(1)
dof_mask = (~torch.isnan(action)).float()
action = action.nan_to_num(nan=0.0)
if action.shape[-1] != 20:
pad_size = 20 - action.shape[-1]
action = torch.cat([
action,
torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
], dim=-1)
dof_mask = torch.cat([
dof_mask,
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
], dim=-1)
if action.shape[-1] != 20:
pad_size = 20 - action.shape[-1]
action = torch.cat([
action,
torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
], dim=-1)
dof_mask = torch.cat([
dof_mask,
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
], dim=-1)
# ==================== ACTION TOKEN REPLACEMENT ====================
all_texts = replace_action_token(
@@ -231,11 +231,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
self.attention_moe = attention_moe
self.mlp_moe = mlp_moe
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
# and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
# TODO: @raushan update config in the hub
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
@@ -1506,7 +1506,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
-2
View File
@@ -1,2 +0,0 @@
# Wall-X policy tests
+1 -1
View File
@@ -31,7 +31,7 @@ from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.wall_x import ( # noqa: E402
WallXConfig,
WallXPolicy,
make_wall_x_pre_post_processors, # noqa: E402
make_wall_x_pre_post_processors,
)
from lerobot.utils.random_utils import set_seed # noqa: E402