fix pre-commit errors

This commit is contained in:
Geoffrey19
2025-12-17 20:27:14 +08:00
committed by Michel Aractingi
parent 9ce6dd9e25
commit a0c9a7d85d
8 changed files with 577 additions and 955 deletions
+7 -7
View File
@@ -6,12 +6,13 @@ This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Languag
## Model Overview
| Feature | Description |
| -------------------- | ------------------------------------------------------------------------ |
| Base Model | Qwen2.5-VL (Vision-Language Model) |
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
| Feature | Description |
| ------------------ | ----------------------------------------------------- |
| Base Model | Qwen2.5-VL (Vision-Language Model) |
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
| Architecture | Mixture of Experts (MoE) with action-specific routing |
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
---
## Citation
@@ -32,4 +33,3 @@ If you use this work, please cite:
## License
This port follows the **Apache 2.0 License**.
@@ -49,7 +49,7 @@ class WallXConfig(PreTrainedConfig):
}
)
# ==================== Action Prediction ====================
# ==================== Action Prediction ====================
# Pretrained model paths
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
@@ -85,20 +85,16 @@ class WallXConfig(PreTrainedConfig):
)
if self.prediction_mode not in ["diffusion", "fast"]:
raise ValueError(
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
# Assign use_fast_tokenizer based on prediction_mode
if self.prediction_mode == "fast":
self.use_fast_tokenizer = True
elif self.prediction_mode == "diffusion":
self.use_fast_tokenizer = False
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
else:
raise ValueError(
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
def validate_features(self) -> None:
"""Validate and set up input/output features."""
+1 -1
View File
@@ -38,4 +38,4 @@ PRIORITY_ORDER = None
GENERATE_SUBTASK_RATIO = 0.0
MODEL_TYPE = "qwen2_5"
TOKENIZER_MAX_LENGTH = 768
TOKENIZER_MAX_LENGTH = 768
File diff suppressed because it is too large Load Diff
@@ -33,6 +33,8 @@ from lerobot.processor import (
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_wall_x_pre_post_processors(
config: WallXConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
@@ -75,9 +77,7 @@ def make_wall_x_pre_post_processors(
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
@@ -123,9 +123,7 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
new_complementary_data["task"] = f"{task}."
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: format each
new_complementary_data["task"] = [
t if t.endswith(".") else f"{t}." for t in task
]
new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
return new_complementary_data
File diff suppressed because it is too large Load Diff
+49 -80
View File
@@ -25,7 +25,7 @@ import random
import re
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any
import torch
from transformers import BatchFeature
@@ -46,11 +46,11 @@ class X2RDataProcessingConfig:
"""
# Action prediction configuration
predict_action_keys: List[str] = field(default_factory=list)
obs_action_keys: List[str] = field(default_factory=list)
predict_action_keys: list[str] = field(default_factory=list)
obs_action_keys: list[str] = field(default_factory=list)
# Image resolution settings for different views
resolution: Dict[str, int] = field(
resolution: dict[str, int] = field(
default_factory=lambda: {
"face_view": -1,
"left_wrist_view": 128,
@@ -63,7 +63,7 @@ class X2RDataProcessingConfig:
split_seed: int = 42
# Instruction handling
priority_order: Optional[Dict[str, float]] = None
priority_order: dict[str, float] | None = None
# Vision model parameters
model_type: str = "qwen2_5"
@@ -77,11 +77,9 @@ class X2RDataProcessingConfig:
"""Post-initialization validation and setup."""
# Validate train/test split
if not 0 < self.train_test_split < 1:
raise ValueError(
f"train_test_split must be between 0 and 1, got {self.train_test_split}"
)
raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}")
def as_dict(self) -> Dict:
def as_dict(self) -> dict:
"""Convert configuration to dictionary format.
Returns:
@@ -105,14 +103,15 @@ class X2RDataProcessingConfig:
raise ValueError(f"Unknown configuration parameter: {key}")
return self
def preprocesser_call(
processor,
images: Optional[Union[List, Any]] = None,
text: Optional[Union[str, List[str]]] = None,
videos: Optional[Union[List, Any]] = None,
padding: Union[bool, str] = False,
truncation: Optional[bool] = None,
max_length: Optional[int] = None,
images: list | Any | None = None,
text: str | list[str] | None = None,
videos: list | Any | None = None,
padding: bool | str = False,
truncation: bool | None = None,
max_length: int | None = None,
return_tensors: str = "pt",
) -> BatchFeature:
"""Unified preprocessing function for Wall-X model handling text, image and video inputs.
@@ -145,9 +144,7 @@ def preprocesser_call(
"""
# Process image inputs
if images is not None and len(images) > 0:
image_inputs = processor.image_processor(
images=images, videos=None, return_tensors=return_tensors
)
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
@@ -155,9 +152,7 @@ def preprocesser_call(
# Process video inputs
if videos is not None:
videos_inputs = processor.image_processor(
images=None, videos=videos, return_tensors=return_tensors
)
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
video_grid_thw = videos_inputs["video_grid_thw"]
else:
videos_inputs = {}
@@ -183,9 +178,7 @@ def preprocesser_call(
break
# Replace image placeholder with actual token count
token_count = image_grid_thw[index].prod() // merge_length
text[i] = text[i].replace(
"<|image_pad|>", "<|placeholder|>" * token_count, 1
)
text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1)
index += 1
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
@@ -197,9 +190,7 @@ def preprocesser_call(
while "<|video_pad|>" in text[i]:
# Replace video placeholder with actual token count
token_count = video_grid_thw[index].prod() // merge_length
text[i] = text[i].replace(
"<|video_pad|>", "<|placeholder|>" * token_count, 1
)
text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1)
index += 1
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
@@ -221,9 +212,7 @@ def preprocesser_call(
labels = torch.full_like(text_inputs.input_ids, -100)
assistant_marker = "<|im_start|>assistant\n"
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
assistant_tokens = processor.tokenizer(
"<|im_start|>assistant\n", add_special_tokens=False
).input_ids
assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids
for i in range(len(text)):
assistant_regions = []
@@ -249,9 +238,7 @@ def preprocesser_call(
# From second part onwards, each part starts with assistant response
for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
if text_inputs.input_ids[i][k] == im_end_token_id:
assistant_regions.append(
(current_pos + len(assistant_tokens), k + 2)
)
assistant_regions.append((current_pos + len(assistant_tokens), k + 2))
break
current_pos += len(part_tokens) + 3
@@ -344,7 +331,7 @@ def process_grounding_points(
coords = [new_x1, new_y1, new_x2, new_y2]
# Return processed point tag
return f'<point>[{", ".join(map(str, coords))}]</point>'
return f"<point>[{', '.join(map(str, coords))}]</point>"
except (ValueError, TypeError):
# Return original content if processing fails
@@ -356,10 +343,10 @@ def process_grounding_points(
def get_frame_instruction(
instruction_info: Dict[str, Any],
frame_idx: Optional[int] = None,
truncate_keys: Optional[List[str]] = None,
) -> Tuple[Dict[str, Any], Optional[int]]:
instruction_info: dict[str, Any],
frame_idx: int | None = None,
truncate_keys: list[str] | None = None,
) -> tuple[dict[str, Any], int | None]:
"""Extract frame-specific instruction from instruction dictionary.
Args:
@@ -388,11 +375,7 @@ def get_frame_instruction(
start_frame, end_frame = map(int, frame_range.split(" "))
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
instruction_for_frame[key] = frame_instruction
if (
truncate_keys is not None
and split_end is None
and key in truncate_keys
):
if truncate_keys is not None and split_end is None and key in truncate_keys:
split_end = end_frame + 1
break
else:
@@ -402,7 +385,7 @@ def get_frame_instruction(
def get_task_instruction(
frame_instruction_info: Dict[str, Any], priority_order: Optional[OrderedDict] = None
frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None
) -> str:
"""Construct task instruction from available instruction fields using priority sampling.
@@ -450,13 +433,13 @@ def get_task_instruction(
def get_wallx_normal_text(
instruction_info: Dict[str, Any],
instruction_info: dict[str, Any],
action_chunk_size: int,
frame_idx: int,
priority_order: Optional[OrderedDict] = None,
img_keys: Optional[List[str]] = None,
priority_order: OrderedDict | None = None,
img_keys: list[str] | None = None,
generate_subtask_ratio: float = 0.0,
) -> Tuple[str, bool]:
) -> tuple[str, bool]:
"""Construct complete multimodal prompt text for Wall-X model.
Formats input using special tokens including:
@@ -488,9 +471,7 @@ def get_wallx_normal_text(
action_fast_symbol = "<|action_fast|>"
# System prologue
prologue = (
f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
)
prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
# User request with observation
user_request = f"{role_start_symbol}user\nObservation:"
@@ -501,9 +482,7 @@ def get_wallx_normal_text(
user_request += "\nInstruction:"
# Get frame-specific instruction
frame_instruction_info, _ = get_frame_instruction(
instruction_info, frame_idx=frame_idx
)
frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx)
generate_subtask = False
priority_keys = ["subtask_generation", "distribute"]
@@ -524,15 +503,11 @@ def get_wallx_normal_text(
output_instruction = frame_instruction_info[key]
break
assistant_output = (
f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
)
assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
generate_subtask = True
else:
# Generate actions
instruction = get_task_instruction(
frame_instruction_info, priority_order=priority_order
)
instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order)
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
@@ -540,7 +515,8 @@ def get_wallx_normal_text(
complete_text = prologue + user_message + assistant_output
return complete_text, generate_subtask
def img_key_mapping(img_keys: List[str]) -> List[str]:
def img_key_mapping(img_keys: list[str]) -> list[str]:
"""Map image keys to camera names.
Args:
@@ -555,16 +531,15 @@ def img_key_mapping(img_keys: List[str]) -> List[str]:
if key in CAMERA_NAME_MAPPING:
key = CAMERA_NAME_MAPPING[key]
else:
if 'view' in key:
key = key.replace('_', ' ')
if "view" in key:
key = key.replace("_", " ")
else:
key = key + " view"
processed_img_keys.append(key)
return processed_img_keys
def get_action_tokens(
normalized_actions: Union[torch.Tensor, List], action_tokenizer
) -> List[List[str]]:
def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]:
"""Convert normalized actions to action token strings.
Args:
@@ -590,8 +565,8 @@ def get_action_tokens(
def pad_action_token_strs(
actions_token_lists: List[List[str]], pad_token: str = "<|endoftext|>"
) -> List[str]:
actions_token_lists: list[list[str]], pad_token: str = "<|endoftext|>"
) -> list[str]:
"""Pad action token lists to same length and join as strings.
Args:
@@ -605,20 +580,18 @@ def pad_action_token_strs(
padded_action_strs = []
for tokens in actions_token_lists:
padded_tokens = (
tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
)
padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
padded_action_strs.append("".join(padded_tokens))
return padded_action_strs
def replace_action_token(
text: List[str],
norm_action: Optional[torch.Tensor],
text: list[str],
norm_action: torch.Tensor | None,
action_tokenizer,
dof_masks: Optional[torch.Tensor] = None,
) -> List[str]:
dof_masks: torch.Tensor | None = None,
) -> list[str]:
"""Replace action placeholders in text with actual action tokens.
Args:
@@ -632,10 +605,7 @@ def replace_action_token(
"""
if action_tokenizer is not None and norm_action is not None:
# Extract actions based on chunk sizes and DOF masks
norm_action = [
action[: 32, dof_masks[i, 0].bool()]
for i, action in enumerate(norm_action)
]
norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)]
# Convert to action tokens and pad
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
@@ -658,4 +628,3 @@ def replace_action_token(
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
return text
+5 -2
View File
@@ -35,10 +35,11 @@ from lerobot.policies.wall_x import ( # noqa: E402
)
from lerobot.utils.random_utils import set_seed # noqa: E402
def test_policy_instantiation():
# Create config
set_seed(42)
config = WallXConfig(device='cuda')
config = WallXConfig(device="cuda")
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -118,6 +119,7 @@ def test_policy_instantiation():
print(f"Action prediction failed: {e}")
raise
def test_config_creation():
"""Test policy config creation through factory."""
try:
@@ -130,6 +132,7 @@ def test_config_creation():
print(f"Config creation failed: {e}")
raise
if __name__ == "__main__":
test_policy_instantiation()
test_config_creation()
test_config_creation()