fix pre-commit errors

This commit is contained in:
Geoffrey19
2025-12-17 20:27:14 +08:00
committed by Michel Aractingi
parent 9ce6dd9e25
commit a0c9a7d85d
8 changed files with 577 additions and 955 deletions
+7 -7
View File
@@ -6,12 +6,13 @@ This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Languag
## Model Overview ## Model Overview
| Feature | Description | | Feature | Description |
| -------------------- | ------------------------------------------------------------------------ | | ------------------ | ----------------------------------------------------- |
| Base Model | Qwen2.5-VL (Vision-Language Model) | | Base Model | Qwen2.5-VL (Vision-Language Model) |
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | | Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
| Architecture | Mixture of Experts (MoE) with action-specific routing | | Architecture | Mixture of Experts (MoE) with action-specific routing |
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | | Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
--- ---
## Citation ## Citation
@@ -32,4 +33,3 @@ If you use this work, please cite:
## License ## License
This port follows the **Apache 2.0 License**. This port follows the **Apache 2.0 License**.
@@ -49,7 +49,7 @@ class WallXConfig(PreTrainedConfig):
} }
) )
# ==================== Action Prediction ==================== # ==================== Action Prediction ====================
# Pretrained model paths # Pretrained model paths
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow" pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
@@ -85,20 +85,16 @@ class WallXConfig(PreTrainedConfig):
) )
if self.prediction_mode not in ["diffusion", "fast"]: if self.prediction_mode not in ["diffusion", "fast"]:
raise ValueError( raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
# Assign use_fast_tokenizer based on prediction_mode # Assign use_fast_tokenizer based on prediction_mode
if self.prediction_mode == "fast": if self.prediction_mode == "fast":
self.use_fast_tokenizer = True self.use_fast_tokenizer = True
elif self.prediction_mode == "diffusion": elif self.prediction_mode == "diffusion":
self.use_fast_tokenizer = False self.use_fast_tokenizer = False
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
else: else:
raise ValueError( raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
def validate_features(self) -> None: def validate_features(self) -> None:
"""Validate and set up input/output features.""" """Validate and set up input/output features."""
+1 -1
View File
@@ -38,4 +38,4 @@ PRIORITY_ORDER = None
GENERATE_SUBTASK_RATIO = 0.0 GENERATE_SUBTASK_RATIO = 0.0
MODEL_TYPE = "qwen2_5" MODEL_TYPE = "qwen2_5"
TOKENIZER_MAX_LENGTH = 768 TOKENIZER_MAX_LENGTH = 768
File diff suppressed because it is too large Load Diff
@@ -33,6 +33,8 @@ from lerobot.processor import (
) )
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_wall_x_pre_post_processors( def make_wall_x_pre_post_processors(
config: WallXConfig, config: WallXConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
@@ -75,9 +77,7 @@ def make_wall_x_pre_post_processors(
output_steps = [ output_steps = [
UnnormalizerProcessorStep( UnnormalizerProcessorStep(
features=config.output_features, features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
norm_map=config.normalization_mapping,
stats=dataset_stats
), ),
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
] ]
@@ -123,9 +123,7 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
new_complementary_data["task"] = f"{task}." new_complementary_data["task"] = f"{task}."
elif isinstance(task, list) and all(isinstance(t, str) for t in task): elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: format each # List of strings: format each
new_complementary_data["task"] = [ new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
t if t.endswith(".") else f"{t}." for t in task
]
return new_complementary_data return new_complementary_data
File diff suppressed because it is too large Load Diff
+49 -80
View File
@@ -25,7 +25,7 @@ import random
import re import re
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any
import torch import torch
from transformers import BatchFeature from transformers import BatchFeature
@@ -46,11 +46,11 @@ class X2RDataProcessingConfig:
""" """
# Action prediction configuration # Action prediction configuration
predict_action_keys: List[str] = field(default_factory=list) predict_action_keys: list[str] = field(default_factory=list)
obs_action_keys: List[str] = field(default_factory=list) obs_action_keys: list[str] = field(default_factory=list)
# Image resolution settings for different views # Image resolution settings for different views
resolution: Dict[str, int] = field( resolution: dict[str, int] = field(
default_factory=lambda: { default_factory=lambda: {
"face_view": -1, "face_view": -1,
"left_wrist_view": 128, "left_wrist_view": 128,
@@ -63,7 +63,7 @@ class X2RDataProcessingConfig:
split_seed: int = 42 split_seed: int = 42
# Instruction handling # Instruction handling
priority_order: Optional[Dict[str, float]] = None priority_order: dict[str, float] | None = None
# Vision model parameters # Vision model parameters
model_type: str = "qwen2_5" model_type: str = "qwen2_5"
@@ -77,11 +77,9 @@ class X2RDataProcessingConfig:
"""Post-initialization validation and setup.""" """Post-initialization validation and setup."""
# Validate train/test split # Validate train/test split
if not 0 < self.train_test_split < 1: if not 0 < self.train_test_split < 1:
raise ValueError( raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}")
f"train_test_split must be between 0 and 1, got {self.train_test_split}"
)
def as_dict(self) -> Dict: def as_dict(self) -> dict:
"""Convert configuration to dictionary format. """Convert configuration to dictionary format.
Returns: Returns:
@@ -105,14 +103,15 @@ class X2RDataProcessingConfig:
raise ValueError(f"Unknown configuration parameter: {key}") raise ValueError(f"Unknown configuration parameter: {key}")
return self return self
def preprocesser_call( def preprocesser_call(
processor, processor,
images: Optional[Union[List, Any]] = None, images: list | Any | None = None,
text: Optional[Union[str, List[str]]] = None, text: str | list[str] | None = None,
videos: Optional[Union[List, Any]] = None, videos: list | Any | None = None,
padding: Union[bool, str] = False, padding: bool | str = False,
truncation: Optional[bool] = None, truncation: bool | None = None,
max_length: Optional[int] = None, max_length: int | None = None,
return_tensors: str = "pt", return_tensors: str = "pt",
) -> BatchFeature: ) -> BatchFeature:
"""Unified preprocessing function for Wall-X model handling text, image and video inputs. """Unified preprocessing function for Wall-X model handling text, image and video inputs.
@@ -145,9 +144,7 @@ def preprocesser_call(
""" """
# Process image inputs # Process image inputs
if images is not None and len(images) > 0: if images is not None and len(images) > 0:
image_inputs = processor.image_processor( image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
images=images, videos=None, return_tensors=return_tensors
)
image_grid_thw = image_inputs["image_grid_thw"] image_grid_thw = image_inputs["image_grid_thw"]
else: else:
image_inputs = {} image_inputs = {}
@@ -155,9 +152,7 @@ def preprocesser_call(
# Process video inputs # Process video inputs
if videos is not None: if videos is not None:
videos_inputs = processor.image_processor( videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
images=None, videos=videos, return_tensors=return_tensors
)
video_grid_thw = videos_inputs["video_grid_thw"] video_grid_thw = videos_inputs["video_grid_thw"]
else: else:
videos_inputs = {} videos_inputs = {}
@@ -183,9 +178,7 @@ def preprocesser_call(
break break
# Replace image placeholder with actual token count # Replace image placeholder with actual token count
token_count = image_grid_thw[index].prod() // merge_length token_count = image_grid_thw[index].prod() // merge_length
text[i] = text[i].replace( text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1)
"<|image_pad|>", "<|placeholder|>" * token_count, 1
)
index += 1 index += 1
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>") text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
@@ -197,9 +190,7 @@ def preprocesser_call(
while "<|video_pad|>" in text[i]: while "<|video_pad|>" in text[i]:
# Replace video placeholder with actual token count # Replace video placeholder with actual token count
token_count = video_grid_thw[index].prod() // merge_length token_count = video_grid_thw[index].prod() // merge_length
text[i] = text[i].replace( text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1)
"<|video_pad|>", "<|placeholder|>" * token_count, 1
)
index += 1 index += 1
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>") text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
@@ -221,9 +212,7 @@ def preprocesser_call(
labels = torch.full_like(text_inputs.input_ids, -100) labels = torch.full_like(text_inputs.input_ids, -100)
assistant_marker = "<|im_start|>assistant\n" assistant_marker = "<|im_start|>assistant\n"
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>") im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
assistant_tokens = processor.tokenizer( assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids
"<|im_start|>assistant\n", add_special_tokens=False
).input_ids
for i in range(len(text)): for i in range(len(text)):
assistant_regions = [] assistant_regions = []
@@ -249,9 +238,7 @@ def preprocesser_call(
# From second part onwards, each part starts with assistant response # From second part onwards, each part starts with assistant response
for k in range(current_pos + 1, len(text_inputs.input_ids[i])): for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
if text_inputs.input_ids[i][k] == im_end_token_id: if text_inputs.input_ids[i][k] == im_end_token_id:
assistant_regions.append( assistant_regions.append((current_pos + len(assistant_tokens), k + 2))
(current_pos + len(assistant_tokens), k + 2)
)
break break
current_pos += len(part_tokens) + 3 current_pos += len(part_tokens) + 3
@@ -344,7 +331,7 @@ def process_grounding_points(
coords = [new_x1, new_y1, new_x2, new_y2] coords = [new_x1, new_y1, new_x2, new_y2]
# Return processed point tag # Return processed point tag
return f'<point>[{", ".join(map(str, coords))}]</point>' return f"<point>[{', '.join(map(str, coords))}]</point>"
except (ValueError, TypeError): except (ValueError, TypeError):
# Return original content if processing fails # Return original content if processing fails
@@ -356,10 +343,10 @@ def process_grounding_points(
def get_frame_instruction( def get_frame_instruction(
instruction_info: Dict[str, Any], instruction_info: dict[str, Any],
frame_idx: Optional[int] = None, frame_idx: int | None = None,
truncate_keys: Optional[List[str]] = None, truncate_keys: list[str] | None = None,
) -> Tuple[Dict[str, Any], Optional[int]]: ) -> tuple[dict[str, Any], int | None]:
"""Extract frame-specific instruction from instruction dictionary. """Extract frame-specific instruction from instruction dictionary.
Args: Args:
@@ -388,11 +375,7 @@ def get_frame_instruction(
start_frame, end_frame = map(int, frame_range.split(" ")) start_frame, end_frame = map(int, frame_range.split(" "))
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx): if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
instruction_for_frame[key] = frame_instruction instruction_for_frame[key] = frame_instruction
if ( if truncate_keys is not None and split_end is None and key in truncate_keys:
truncate_keys is not None
and split_end is None
and key in truncate_keys
):
split_end = end_frame + 1 split_end = end_frame + 1
break break
else: else:
@@ -402,7 +385,7 @@ def get_frame_instruction(
def get_task_instruction( def get_task_instruction(
frame_instruction_info: Dict[str, Any], priority_order: Optional[OrderedDict] = None frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None
) -> str: ) -> str:
"""Construct task instruction from available instruction fields using priority sampling. """Construct task instruction from available instruction fields using priority sampling.
@@ -450,13 +433,13 @@ def get_task_instruction(
def get_wallx_normal_text( def get_wallx_normal_text(
instruction_info: Dict[str, Any], instruction_info: dict[str, Any],
action_chunk_size: int, action_chunk_size: int,
frame_idx: int, frame_idx: int,
priority_order: Optional[OrderedDict] = None, priority_order: OrderedDict | None = None,
img_keys: Optional[List[str]] = None, img_keys: list[str] | None = None,
generate_subtask_ratio: float = 0.0, generate_subtask_ratio: float = 0.0,
) -> Tuple[str, bool]: ) -> tuple[str, bool]:
"""Construct complete multimodal prompt text for Wall-X model. """Construct complete multimodal prompt text for Wall-X model.
Formats input using special tokens including: Formats input using special tokens including:
@@ -488,9 +471,7 @@ def get_wallx_normal_text(
action_fast_symbol = "<|action_fast|>" action_fast_symbol = "<|action_fast|>"
# System prologue # System prologue
prologue = ( prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
)
# User request with observation # User request with observation
user_request = f"{role_start_symbol}user\nObservation:" user_request = f"{role_start_symbol}user\nObservation:"
@@ -501,9 +482,7 @@ def get_wallx_normal_text(
user_request += "\nInstruction:" user_request += "\nInstruction:"
# Get frame-specific instruction # Get frame-specific instruction
frame_instruction_info, _ = get_frame_instruction( frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx)
instruction_info, frame_idx=frame_idx
)
generate_subtask = False generate_subtask = False
priority_keys = ["subtask_generation", "distribute"] priority_keys = ["subtask_generation", "distribute"]
@@ -524,15 +503,11 @@ def get_wallx_normal_text(
output_instruction = frame_instruction_info[key] output_instruction = frame_instruction_info[key]
break break
assistant_output = ( assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
)
generate_subtask = True generate_subtask = True
else: else:
# Generate actions # Generate actions
instruction = get_task_instruction( instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order)
frame_instruction_info, priority_order=priority_order
)
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n" text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}" assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
@@ -540,7 +515,8 @@ def get_wallx_normal_text(
complete_text = prologue + user_message + assistant_output complete_text = prologue + user_message + assistant_output
return complete_text, generate_subtask return complete_text, generate_subtask
def img_key_mapping(img_keys: List[str]) -> List[str]:
def img_key_mapping(img_keys: list[str]) -> list[str]:
"""Map image keys to camera names. """Map image keys to camera names.
Args: Args:
@@ -555,16 +531,15 @@ def img_key_mapping(img_keys: List[str]) -> List[str]:
if key in CAMERA_NAME_MAPPING: if key in CAMERA_NAME_MAPPING:
key = CAMERA_NAME_MAPPING[key] key = CAMERA_NAME_MAPPING[key]
else: else:
if 'view' in key: if "view" in key:
key = key.replace('_', ' ') key = key.replace("_", " ")
else: else:
key = key + " view" key = key + " view"
processed_img_keys.append(key) processed_img_keys.append(key)
return processed_img_keys return processed_img_keys
def get_action_tokens(
normalized_actions: Union[torch.Tensor, List], action_tokenizer def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]:
) -> List[List[str]]:
"""Convert normalized actions to action token strings. """Convert normalized actions to action token strings.
Args: Args:
@@ -590,8 +565,8 @@ def get_action_tokens(
def pad_action_token_strs( def pad_action_token_strs(
actions_token_lists: List[List[str]], pad_token: str = "<|endoftext|>" actions_token_lists: list[list[str]], pad_token: str = "<|endoftext|>"
) -> List[str]: ) -> list[str]:
"""Pad action token lists to same length and join as strings. """Pad action token lists to same length and join as strings.
Args: Args:
@@ -605,20 +580,18 @@ def pad_action_token_strs(
padded_action_strs = [] padded_action_strs = []
for tokens in actions_token_lists: for tokens in actions_token_lists:
padded_tokens = ( padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
)
padded_action_strs.append("".join(padded_tokens)) padded_action_strs.append("".join(padded_tokens))
return padded_action_strs return padded_action_strs
def replace_action_token( def replace_action_token(
text: List[str], text: list[str],
norm_action: Optional[torch.Tensor], norm_action: torch.Tensor | None,
action_tokenizer, action_tokenizer,
dof_masks: Optional[torch.Tensor] = None, dof_masks: torch.Tensor | None = None,
) -> List[str]: ) -> list[str]:
"""Replace action placeholders in text with actual action tokens. """Replace action placeholders in text with actual action tokens.
Args: Args:
@@ -632,10 +605,7 @@ def replace_action_token(
""" """
if action_tokenizer is not None and norm_action is not None: if action_tokenizer is not None and norm_action is not None:
# Extract actions based on chunk sizes and DOF masks # Extract actions based on chunk sizes and DOF masks
norm_action = [ norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)]
action[: 32, dof_masks[i, 0].bool()]
for i, action in enumerate(norm_action)
]
# Convert to action tokens and pad # Convert to action tokens and pad
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer) actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
@@ -658,4 +628,3 @@ def replace_action_token(
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text] text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
return text return text
+5 -2
View File
@@ -35,10 +35,11 @@ from lerobot.policies.wall_x import ( # noqa: E402
) )
from lerobot.utils.random_utils import set_seed # noqa: E402 from lerobot.utils.random_utils import set_seed # noqa: E402
def test_policy_instantiation(): def test_policy_instantiation():
# Create config # Create config
set_seed(42) set_seed(42)
config = WallXConfig(device='cuda') config = WallXConfig(device="cuda")
# Set up input_features and output_features in the config # Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
@@ -118,6 +119,7 @@ def test_policy_instantiation():
print(f"Action prediction failed: {e}") print(f"Action prediction failed: {e}")
raise raise
def test_config_creation(): def test_config_creation():
"""Test policy config creation through factory.""" """Test policy config creation through factory."""
try: try:
@@ -130,6 +132,7 @@ def test_config_creation():
print(f"Config creation failed: {e}") print(f"Config creation failed: {e}")
raise raise
if __name__ == "__main__": if __name__ == "__main__":
test_policy_instantiation() test_policy_instantiation()
test_config_creation() test_config_creation()