From fc6262e23db07ca8c924e43337626770579fdf4b Mon Sep 17 00:00:00 2001 From: Geoffrey19 Date: Wed, 10 Dec 2025 22:40:05 +0800 Subject: [PATCH] remove flash-attn requirement && fix bug in inference and fast mode --- pyproject.toml | 3 +- .../policies/wall_x/configuration_wall_x.py | 23 ++++--- src/lerobot/policies/wall_x/constant.py | 2 - .../policies/wall_x/modeling_wall_x.py | 68 ++++++++++++------- .../qwen_model/configuration_qwen2_5_vl.py | 5 -- .../wall_x/qwen_model/qwen2_5_vl_moe.py | 2 +- tests/policies/wall_x/__init__.py | 2 - tests/policies/wall_x/test_wallx.py | 2 +- 8 files changed, 63 insertions(+), 44 deletions(-) delete mode 100644 tests/policies/wall_x/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 57e00a952..f04e0dabf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,8 +131,7 @@ wallx = [ "peft==0.17.1", "scipy==1.15.3", "torchdiffeq==0.2.5", - "qwen_vl_utils==0.0.11", - "flash-attn==2.7.4.post1" + "qwen_vl_utils==0.0.11" ] pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py index a4d67eb0f..456ac993e 100644 --- a/src/lerobot/policies/wall_x/configuration_wall_x.py +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -53,13 +53,15 @@ class WallXConfig(PreTrainedConfig): # Pretrained model paths pretrained_name_or_path: str = "x-square-robot/wall-oss-flow" + # Tokenizer settings + action_tokenizer_path: str | None = "physical-intelligence/fast" + # Action prediction mode: "diffusion" or "fast" prediction_mode: str = "diffusion" - # Tokenizer settings - use_fast_tokenizer: bool = False # True: train FAST, False: train Flow - action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode) - + # Attention Implementation, options: "eager", "flash_attention_2", "sdpa" + # NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation + attn_implementation: str = "eager" # ==================== Optimizer Presets ==================== optimizer_lr: float = 2e-5 @@ -87,11 +89,16 @@ class WallXConfig(PreTrainedConfig): f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}" ) - # Sync prediction_mode with use_fast_tokenizer - if self.use_fast_tokenizer: - self.prediction_mode = "fast" + # Assign use_fast_tokenizer based on prediction_mode + if self.prediction_mode == "fast": + self.use_fast_tokenizer = True + elif self.prediction_mode == "diffusion": + self.use_fast_tokenizer = False + self.action_tokenizer_path = None # disable action tokenizer for diffusion mode else: - self.prediction_mode = "diffusion" + raise ValueError( + f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}" + ) def validate_features(self) -> None: """Validate and set up input/output features.""" diff --git a/src/lerobot/policies/wall_x/constant.py b/src/lerobot/policies/wall_x/constant.py index a894cef06..597d24951 100644 --- a/src/lerobot/policies/wall_x/constant.py +++ b/src/lerobot/policies/wall_x/constant.py @@ -18,8 +18,6 @@ Wall-X Constants and Configuration Data. """ -from lerobot.utils.constants import OBS_STATE, OBS_IMAGES, ACTION - CAMERA_NAME_MAPPING = { "face_view": "front view", "left_wrist_view": "left wrist view", diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index a99995bef..15a162c78 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -63,8 +63,22 @@ from lerobot.policies.utils import populate_queues from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.utils.constants import ACTION, OBS_STATE -from lerobot.policies.wall_x.utils import * -from lerobot.policies.wall_x.constant import * +from lerobot.policies.wall_x.utils import ( + replace_action_token, + preprocesser_call, + get_wallx_normal_text, + process_grounding_points, +) +from lerobot.policies.wall_x.constant import ( + MODEL_TYPE, + TOKENIZER_MAX_LENGTH, + PRIORITY_ORDER, + GENERATE_SUBTASK_RATIO, + RESOLUTION, + MAX_PIXELS, + MIN_PIXELS, + IMAGE_FACTOR, +) from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLForConditionalGeneration, @@ -261,6 +275,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): pretrained_name_or_path, config=None, action_tokenizer_path=None, + attn_implementation: str = 'eager', cache_dir: str | PathLike | None = None, force_download: bool = False, local_files_only: bool = False, @@ -276,6 +291,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): pretrained_model_path (str): Model directory path containing model.safetensors file config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config + attn_implementation (str, optional): Attention implementation, if None will load from default config **kwargs: Additional arguments Returns: @@ -292,14 +308,18 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): strict=strict, **kwargs, ) + if attn_implementation is not None: + config._attn_implementation = attn_implementation processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True) if action_tokenizer_path is not None: - processor.action_processor = AutoProcessor.from_pretrained( + action_tokenizer = AutoProcessor.from_pretrained( action_tokenizer_path, trust_remote_code=True ) - + processor.action_processor = action_tokenizer + else: + action_tokenizer = None # Initialize model with configuration and processor - model = cls(config, processor=processor, **kwargs) + model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs) # Resize token embeddings to match processor tokenizer vocabulary size model.resize_token_embeddings(len(processor.tokenizer)) @@ -379,6 +399,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): self.flow_loss_weight = flow_loss_weight self.use_fast_tokenizer = use_fast_tokenizer self.processor = processor + self.action_tokenizer = action_tokenizer # Define action token IDs self.define_action_token_id() @@ -1279,7 +1300,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): if labels is not None: labels = labels[:, split_pos + 3 :] else: - raise Warning( + raise ValueError( "input_ids does not contain the generation prompt tokens <|im_start|>assistant" ) @@ -1826,7 +1847,7 @@ class WallXPolicy(PreTrainedPolicy): self.config = config # Initialize the wall-x model - self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path) + self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path, attn_implementation=config.attn_implementation) self.model.to(config.device) self.model.to_bfloat16_for_selected_params() @@ -1950,22 +1971,23 @@ class WallXPolicy(PreTrainedPolicy): ], dim=-1) # ==================== PROCESS ACTIONS ==================== - action = batch[ACTION] # (batch_size, chunk_size, action_dim) - if action.dim() == 2: - action = action.unsqueeze(1) - dof_mask = (~torch.isnan(action)).float() - action = action.nan_to_num(nan=0.0) - - if action.shape[-1] != 20: - pad_size = 20 - action.shape[-1] - action = torch.cat([ - action, - torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device) - ], dim=-1) - dof_mask = torch.cat([ - dof_mask, - torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device) - ], dim=-1) + action = batch.get(ACTION, None) # (batch_size, chunk_size, action_dim) + if action is not None: + if action.dim() == 2: + action = action.unsqueeze(1) + dof_mask = (~torch.isnan(action)).float() + action = action.nan_to_num(nan=0.0) + + if action.shape[-1] != 20: + pad_size = 20 - action.shape[-1] + action = torch.cat([ + action, + torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device) + ], dim=-1) + dof_mask = torch.cat([ + dof_mask, + torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device) + ], dim=-1) # ==================== ACTION TOKEN REPLACEMENT ==================== all_texts = replace_action_token( diff --git a/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py index 439a36923..e9efc1b05 100644 --- a/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py +++ b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py @@ -231,11 +231,6 @@ class Qwen2_5_VLConfig(PretrainedConfig): self.attention_moe = attention_moe self.mlp_moe = mlp_moe - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, move it to 'rope_type'. - # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations - # one can set it to "linear"/"dynamic" etc. to have scaled RoPE - # TODO: @raushan update config in the hub if self.rope_scaling is not None and "type" in self.rope_scaling: if self.rope_scaling["type"] == "mrope": self.rope_scaling["type"] = "default" diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py index 438ac044c..a21b0b348 100644 --- a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -1506,7 +1506,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): - The device to plcae the 4D attention mask on. + The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): diff --git a/tests/policies/wall_x/__init__.py b/tests/policies/wall_x/__init__.py deleted file mode 100644 index 7f5f042a0..000000000 --- a/tests/policies/wall_x/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Wall-X policy tests - diff --git a/tests/policies/wall_x/test_wallx.py b/tests/policies/wall_x/test_wallx.py index ac8ee59da..2440ec98b 100644 --- a/tests/policies/wall_x/test_wallx.py +++ b/tests/policies/wall_x/test_wallx.py @@ -31,7 +31,7 @@ from lerobot.policies.factory import make_policy_config # noqa: E402 from lerobot.policies.wall_x import ( # noqa: E402 WallXConfig, WallXPolicy, - make_wall_x_pre_post_processors, # noqa: E402 + make_wall_x_pre_post_processors, ) from lerobot.utils.random_utils import set_seed # noqa: E402