mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
remove flash-attn requirement && fix bug in inference and fast mode
This commit is contained in:
committed by
Michel Aractingi
parent
d2b16afb12
commit
fc6262e23d
+1
-2
@@ -131,8 +131,7 @@ wallx = [
|
|||||||
"peft==0.17.1",
|
"peft==0.17.1",
|
||||||
"scipy==1.15.3",
|
"scipy==1.15.3",
|
||||||
"torchdiffeq==0.2.5",
|
"torchdiffeq==0.2.5",
|
||||||
"qwen_vl_utils==0.0.11",
|
"qwen_vl_utils==0.0.11"
|
||||||
"flash-attn==2.7.4.post1"
|
|
||||||
]
|
]
|
||||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||||
|
|||||||
@@ -53,13 +53,15 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
# Pretrained model paths
|
# Pretrained model paths
|
||||||
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||||
|
|
||||||
|
# Tokenizer settings
|
||||||
|
action_tokenizer_path: str | None = "physical-intelligence/fast"
|
||||||
|
|
||||||
# Action prediction mode: "diffusion" or "fast"
|
# Action prediction mode: "diffusion" or "fast"
|
||||||
prediction_mode: str = "diffusion"
|
prediction_mode: str = "diffusion"
|
||||||
|
|
||||||
# Tokenizer settings
|
# Attention Implementation, options: "eager", "flash_attention_2", "sdpa"
|
||||||
use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
|
# NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation
|
||||||
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
|
attn_implementation: str = "eager"
|
||||||
|
|
||||||
|
|
||||||
# ==================== Optimizer Presets ====================
|
# ==================== Optimizer Presets ====================
|
||||||
optimizer_lr: float = 2e-5
|
optimizer_lr: float = 2e-5
|
||||||
@@ -87,11 +89,16 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sync prediction_mode with use_fast_tokenizer
|
# Assign use_fast_tokenizer based on prediction_mode
|
||||||
if self.use_fast_tokenizer:
|
if self.prediction_mode == "fast":
|
||||||
self.prediction_mode = "fast"
|
self.use_fast_tokenizer = True
|
||||||
|
elif self.prediction_mode == "diffusion":
|
||||||
|
self.use_fast_tokenizer = False
|
||||||
|
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
|
||||||
else:
|
else:
|
||||||
self.prediction_mode = "diffusion"
|
raise ValueError(
|
||||||
|
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
||||||
|
)
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate and set up input/output features."""
|
"""Validate and set up input/output features."""
|
||||||
|
|||||||
@@ -18,8 +18,6 @@
|
|||||||
Wall-X Constants and Configuration Data.
|
Wall-X Constants and Configuration Data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from lerobot.utils.constants import OBS_STATE, OBS_IMAGES, ACTION
|
|
||||||
|
|
||||||
CAMERA_NAME_MAPPING = {
|
CAMERA_NAME_MAPPING = {
|
||||||
"face_view": "front view",
|
"face_view": "front view",
|
||||||
"left_wrist_view": "left wrist view",
|
"left_wrist_view": "left wrist view",
|
||||||
|
|||||||
@@ -63,8 +63,22 @@ from lerobot.policies.utils import populate_queues
|
|||||||
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
|
||||||
from lerobot.policies.wall_x.utils import *
|
from lerobot.policies.wall_x.utils import (
|
||||||
from lerobot.policies.wall_x.constant import *
|
replace_action_token,
|
||||||
|
preprocesser_call,
|
||||||
|
get_wallx_normal_text,
|
||||||
|
process_grounding_points,
|
||||||
|
)
|
||||||
|
from lerobot.policies.wall_x.constant import (
|
||||||
|
MODEL_TYPE,
|
||||||
|
TOKENIZER_MAX_LENGTH,
|
||||||
|
PRIORITY_ORDER,
|
||||||
|
GENERATE_SUBTASK_RATIO,
|
||||||
|
RESOLUTION,
|
||||||
|
MAX_PIXELS,
|
||||||
|
MIN_PIXELS,
|
||||||
|
IMAGE_FACTOR,
|
||||||
|
)
|
||||||
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
||||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||||
Qwen2_5_VLForConditionalGeneration,
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
@@ -261,6 +275,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
pretrained_name_or_path,
|
pretrained_name_or_path,
|
||||||
config=None,
|
config=None,
|
||||||
action_tokenizer_path=None,
|
action_tokenizer_path=None,
|
||||||
|
attn_implementation: str = 'eager',
|
||||||
cache_dir: str | PathLike | None = None,
|
cache_dir: str | PathLike | None = None,
|
||||||
force_download: bool = False,
|
force_download: bool = False,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
@@ -276,6 +291,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
pretrained_model_path (str): Model directory path containing model.safetensors file
|
pretrained_model_path (str): Model directory path containing model.safetensors file
|
||||||
config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
|
config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
|
||||||
action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
|
action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
|
||||||
|
attn_implementation (str, optional): Attention implementation, if None will load from default config
|
||||||
**kwargs: Additional arguments
|
**kwargs: Additional arguments
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -292,14 +308,18 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
strict=strict,
|
strict=strict,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
if attn_implementation is not None:
|
||||||
|
config._attn_implementation = attn_implementation
|
||||||
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
|
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
|
||||||
if action_tokenizer_path is not None:
|
if action_tokenizer_path is not None:
|
||||||
processor.action_processor = AutoProcessor.from_pretrained(
|
action_tokenizer = AutoProcessor.from_pretrained(
|
||||||
action_tokenizer_path, trust_remote_code=True
|
action_tokenizer_path, trust_remote_code=True
|
||||||
)
|
)
|
||||||
|
processor.action_processor = action_tokenizer
|
||||||
|
else:
|
||||||
|
action_tokenizer = None
|
||||||
# Initialize model with configuration and processor
|
# Initialize model with configuration and processor
|
||||||
model = cls(config, processor=processor, **kwargs)
|
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
|
||||||
|
|
||||||
# Resize token embeddings to match processor tokenizer vocabulary size
|
# Resize token embeddings to match processor tokenizer vocabulary size
|
||||||
model.resize_token_embeddings(len(processor.tokenizer))
|
model.resize_token_embeddings(len(processor.tokenizer))
|
||||||
@@ -379,6 +399,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
self.flow_loss_weight = flow_loss_weight
|
self.flow_loss_weight = flow_loss_weight
|
||||||
self.use_fast_tokenizer = use_fast_tokenizer
|
self.use_fast_tokenizer = use_fast_tokenizer
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
self.action_tokenizer = action_tokenizer
|
||||||
|
|
||||||
# Define action token IDs
|
# Define action token IDs
|
||||||
self.define_action_token_id()
|
self.define_action_token_id()
|
||||||
@@ -1279,7 +1300,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = labels[:, split_pos + 3 :]
|
labels = labels[:, split_pos + 3 :]
|
||||||
else:
|
else:
|
||||||
raise Warning(
|
raise ValueError(
|
||||||
"input_ids does not contain the generation prompt tokens <|im_start|>assistant"
|
"input_ids does not contain the generation prompt tokens <|im_start|>assistant"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1826,7 +1847,7 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Initialize the wall-x model
|
# Initialize the wall-x model
|
||||||
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
|
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path, attn_implementation=config.attn_implementation)
|
||||||
self.model.to(config.device)
|
self.model.to(config.device)
|
||||||
self.model.to_bfloat16_for_selected_params()
|
self.model.to_bfloat16_for_selected_params()
|
||||||
|
|
||||||
@@ -1950,7 +1971,8 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
], dim=-1)
|
], dim=-1)
|
||||||
|
|
||||||
# ==================== PROCESS ACTIONS ====================
|
# ==================== PROCESS ACTIONS ====================
|
||||||
action = batch[ACTION] # (batch_size, chunk_size, action_dim)
|
action = batch.get(ACTION, None) # (batch_size, chunk_size, action_dim)
|
||||||
|
if action is not None:
|
||||||
if action.dim() == 2:
|
if action.dim() == 2:
|
||||||
action = action.unsqueeze(1)
|
action = action.unsqueeze(1)
|
||||||
dof_mask = (~torch.isnan(action)).float()
|
dof_mask = (~torch.isnan(action)).float()
|
||||||
|
|||||||
@@ -231,11 +231,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
|
|||||||
self.attention_moe = attention_moe
|
self.attention_moe = attention_moe
|
||||||
self.mlp_moe = mlp_moe
|
self.mlp_moe = mlp_moe
|
||||||
|
|
||||||
# Validate the correctness of rotary position embeddings parameters
|
|
||||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
|
||||||
# and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
|
|
||||||
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
|
||||||
# TODO: @raushan update config in the hub
|
|
||||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||||
if self.rope_scaling["type"] == "mrope":
|
if self.rope_scaling["type"] == "mrope":
|
||||||
self.rope_scaling["type"] = "default"
|
self.rope_scaling["type"] = "default"
|
||||||
|
|||||||
@@ -1506,7 +1506,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
|||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
The device to plcae the 4D attention mask on.
|
The device to place the 4D attention mask on.
|
||||||
cache_position (`torch.Tensor`):
|
cache_position (`torch.Tensor`):
|
||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
|
|||||||
@@ -1,2 +0,0 @@
|
|||||||
# Wall-X policy tests
|
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ from lerobot.policies.factory import make_policy_config # noqa: E402
|
|||||||
from lerobot.policies.wall_x import ( # noqa: E402
|
from lerobot.policies.wall_x import ( # noqa: E402
|
||||||
WallXConfig,
|
WallXConfig,
|
||||||
WallXPolicy,
|
WallXPolicy,
|
||||||
make_wall_x_pre_post_processors, # noqa: E402
|
make_wall_x_pre_post_processors,
|
||||||
)
|
)
|
||||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user