mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix pre-commit errors
This commit is contained in:
committed by
Michel Aractingi
parent
9ce6dd9e25
commit
a0c9a7d85d
@@ -6,12 +6,13 @@ This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Languag
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | Description |
|
||||
| -------------------- | ------------------------------------------------------------------------ |
|
||||
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
||||
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
||||
| Architecture | Mixture of Experts (MoE) with action-specific routing |
|
||||
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
||||
| Feature | Description |
|
||||
| ------------------ | ----------------------------------------------------- |
|
||||
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
||||
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
||||
| Architecture | Mixture of Experts (MoE) with action-specific routing |
|
||||
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
@@ -32,4 +33,3 @@ If you use this work, please cite:
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**.
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ class WallXConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
# ==================== Action Prediction ====================
|
||||
# ==================== Action Prediction ====================
|
||||
# Pretrained model paths
|
||||
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||
|
||||
@@ -85,20 +85,16 @@ class WallXConfig(PreTrainedConfig):
|
||||
)
|
||||
|
||||
if self.prediction_mode not in ["diffusion", "fast"]:
|
||||
raise ValueError(
|
||||
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
||||
)
|
||||
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||
|
||||
# Assign use_fast_tokenizer based on prediction_mode
|
||||
if self.prediction_mode == "fast":
|
||||
self.use_fast_tokenizer = True
|
||||
elif self.prediction_mode == "diffusion":
|
||||
self.use_fast_tokenizer = False
|
||||
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
|
||||
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
||||
)
|
||||
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
|
||||
@@ -38,4 +38,4 @@ PRIORITY_ORDER = None
|
||||
GENERATE_SUBTASK_RATIO = 0.0
|
||||
MODEL_TYPE = "qwen2_5"
|
||||
|
||||
TOKENIZER_MAX_LENGTH = 768
|
||||
TOKENIZER_MAX_LENGTH = 768
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -33,6 +33,8 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_wall_x_pre_post_processors(
|
||||
config: WallXConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
@@ -75,9 +77,7 @@ def make_wall_x_pre_post_processors(
|
||||
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
@@ -123,9 +123,7 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
|
||||
new_complementary_data["task"] = f"{task}."
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: format each
|
||||
new_complementary_data["task"] = [
|
||||
t if t.endswith(".") else f"{t}." for t in task
|
||||
]
|
||||
new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,7 @@ import random
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import BatchFeature
|
||||
@@ -46,11 +46,11 @@ class X2RDataProcessingConfig:
|
||||
"""
|
||||
|
||||
# Action prediction configuration
|
||||
predict_action_keys: List[str] = field(default_factory=list)
|
||||
obs_action_keys: List[str] = field(default_factory=list)
|
||||
predict_action_keys: list[str] = field(default_factory=list)
|
||||
obs_action_keys: list[str] = field(default_factory=list)
|
||||
|
||||
# Image resolution settings for different views
|
||||
resolution: Dict[str, int] = field(
|
||||
resolution: dict[str, int] = field(
|
||||
default_factory=lambda: {
|
||||
"face_view": -1,
|
||||
"left_wrist_view": 128,
|
||||
@@ -63,7 +63,7 @@ class X2RDataProcessingConfig:
|
||||
split_seed: int = 42
|
||||
|
||||
# Instruction handling
|
||||
priority_order: Optional[Dict[str, float]] = None
|
||||
priority_order: dict[str, float] | None = None
|
||||
|
||||
# Vision model parameters
|
||||
model_type: str = "qwen2_5"
|
||||
@@ -77,11 +77,9 @@ class X2RDataProcessingConfig:
|
||||
"""Post-initialization validation and setup."""
|
||||
# Validate train/test split
|
||||
if not 0 < self.train_test_split < 1:
|
||||
raise ValueError(
|
||||
f"train_test_split must be between 0 and 1, got {self.train_test_split}"
|
||||
)
|
||||
raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}")
|
||||
|
||||
def as_dict(self) -> Dict:
|
||||
def as_dict(self) -> dict:
|
||||
"""Convert configuration to dictionary format.
|
||||
|
||||
Returns:
|
||||
@@ -105,14 +103,15 @@ class X2RDataProcessingConfig:
|
||||
raise ValueError(f"Unknown configuration parameter: {key}")
|
||||
return self
|
||||
|
||||
|
||||
def preprocesser_call(
|
||||
processor,
|
||||
images: Optional[Union[List, Any]] = None,
|
||||
text: Optional[Union[str, List[str]]] = None,
|
||||
videos: Optional[Union[List, Any]] = None,
|
||||
padding: Union[bool, str] = False,
|
||||
truncation: Optional[bool] = None,
|
||||
max_length: Optional[int] = None,
|
||||
images: list | Any | None = None,
|
||||
text: str | list[str] | None = None,
|
||||
videos: list | Any | None = None,
|
||||
padding: bool | str = False,
|
||||
truncation: bool | None = None,
|
||||
max_length: int | None = None,
|
||||
return_tensors: str = "pt",
|
||||
) -> BatchFeature:
|
||||
"""Unified preprocessing function for Wall-X model handling text, image and video inputs.
|
||||
@@ -145,9 +144,7 @@ def preprocesser_call(
|
||||
"""
|
||||
# Process image inputs
|
||||
if images is not None and len(images) > 0:
|
||||
image_inputs = processor.image_processor(
|
||||
images=images, videos=None, return_tensors=return_tensors
|
||||
)
|
||||
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
@@ -155,9 +152,7 @@ def preprocesser_call(
|
||||
|
||||
# Process video inputs
|
||||
if videos is not None:
|
||||
videos_inputs = processor.image_processor(
|
||||
images=None, videos=videos, return_tensors=return_tensors
|
||||
)
|
||||
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
@@ -183,9 +178,7 @@ def preprocesser_call(
|
||||
break
|
||||
# Replace image placeholder with actual token count
|
||||
token_count = image_grid_thw[index].prod() // merge_length
|
||||
text[i] = text[i].replace(
|
||||
"<|image_pad|>", "<|placeholder|>" * token_count, 1
|
||||
)
|
||||
text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
|
||||
|
||||
@@ -197,9 +190,7 @@ def preprocesser_call(
|
||||
while "<|video_pad|>" in text[i]:
|
||||
# Replace video placeholder with actual token count
|
||||
token_count = video_grid_thw[index].prod() // merge_length
|
||||
text[i] = text[i].replace(
|
||||
"<|video_pad|>", "<|placeholder|>" * token_count, 1
|
||||
)
|
||||
text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
|
||||
|
||||
@@ -221,9 +212,7 @@ def preprocesser_call(
|
||||
labels = torch.full_like(text_inputs.input_ids, -100)
|
||||
assistant_marker = "<|im_start|>assistant\n"
|
||||
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
||||
assistant_tokens = processor.tokenizer(
|
||||
"<|im_start|>assistant\n", add_special_tokens=False
|
||||
).input_ids
|
||||
assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids
|
||||
|
||||
for i in range(len(text)):
|
||||
assistant_regions = []
|
||||
@@ -249,9 +238,7 @@ def preprocesser_call(
|
||||
# From second part onwards, each part starts with assistant response
|
||||
for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
|
||||
if text_inputs.input_ids[i][k] == im_end_token_id:
|
||||
assistant_regions.append(
|
||||
(current_pos + len(assistant_tokens), k + 2)
|
||||
)
|
||||
assistant_regions.append((current_pos + len(assistant_tokens), k + 2))
|
||||
break
|
||||
current_pos += len(part_tokens) + 3
|
||||
|
||||
@@ -344,7 +331,7 @@ def process_grounding_points(
|
||||
coords = [new_x1, new_y1, new_x2, new_y2]
|
||||
|
||||
# Return processed point tag
|
||||
return f'<point>[{", ".join(map(str, coords))}]</point>'
|
||||
return f"<point>[{', '.join(map(str, coords))}]</point>"
|
||||
|
||||
except (ValueError, TypeError):
|
||||
# Return original content if processing fails
|
||||
@@ -356,10 +343,10 @@ def process_grounding_points(
|
||||
|
||||
|
||||
def get_frame_instruction(
|
||||
instruction_info: Dict[str, Any],
|
||||
frame_idx: Optional[int] = None,
|
||||
truncate_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Dict[str, Any], Optional[int]]:
|
||||
instruction_info: dict[str, Any],
|
||||
frame_idx: int | None = None,
|
||||
truncate_keys: list[str] | None = None,
|
||||
) -> tuple[dict[str, Any], int | None]:
|
||||
"""Extract frame-specific instruction from instruction dictionary.
|
||||
|
||||
Args:
|
||||
@@ -388,11 +375,7 @@ def get_frame_instruction(
|
||||
start_frame, end_frame = map(int, frame_range.split(" "))
|
||||
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
|
||||
instruction_for_frame[key] = frame_instruction
|
||||
if (
|
||||
truncate_keys is not None
|
||||
and split_end is None
|
||||
and key in truncate_keys
|
||||
):
|
||||
if truncate_keys is not None and split_end is None and key in truncate_keys:
|
||||
split_end = end_frame + 1
|
||||
break
|
||||
else:
|
||||
@@ -402,7 +385,7 @@ def get_frame_instruction(
|
||||
|
||||
|
||||
def get_task_instruction(
|
||||
frame_instruction_info: Dict[str, Any], priority_order: Optional[OrderedDict] = None
|
||||
frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None
|
||||
) -> str:
|
||||
"""Construct task instruction from available instruction fields using priority sampling.
|
||||
|
||||
@@ -450,13 +433,13 @@ def get_task_instruction(
|
||||
|
||||
|
||||
def get_wallx_normal_text(
|
||||
instruction_info: Dict[str, Any],
|
||||
instruction_info: dict[str, Any],
|
||||
action_chunk_size: int,
|
||||
frame_idx: int,
|
||||
priority_order: Optional[OrderedDict] = None,
|
||||
img_keys: Optional[List[str]] = None,
|
||||
priority_order: OrderedDict | None = None,
|
||||
img_keys: list[str] | None = None,
|
||||
generate_subtask_ratio: float = 0.0,
|
||||
) -> Tuple[str, bool]:
|
||||
) -> tuple[str, bool]:
|
||||
"""Construct complete multimodal prompt text for Wall-X model.
|
||||
|
||||
Formats input using special tokens including:
|
||||
@@ -488,9 +471,7 @@ def get_wallx_normal_text(
|
||||
action_fast_symbol = "<|action_fast|>"
|
||||
|
||||
# System prologue
|
||||
prologue = (
|
||||
f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
|
||||
)
|
||||
prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
|
||||
|
||||
# User request with observation
|
||||
user_request = f"{role_start_symbol}user\nObservation:"
|
||||
@@ -501,9 +482,7 @@ def get_wallx_normal_text(
|
||||
user_request += "\nInstruction:"
|
||||
|
||||
# Get frame-specific instruction
|
||||
frame_instruction_info, _ = get_frame_instruction(
|
||||
instruction_info, frame_idx=frame_idx
|
||||
)
|
||||
frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx)
|
||||
|
||||
generate_subtask = False
|
||||
priority_keys = ["subtask_generation", "distribute"]
|
||||
@@ -524,15 +503,11 @@ def get_wallx_normal_text(
|
||||
output_instruction = frame_instruction_info[key]
|
||||
break
|
||||
|
||||
assistant_output = (
|
||||
f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
|
||||
)
|
||||
assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
|
||||
generate_subtask = True
|
||||
else:
|
||||
# Generate actions
|
||||
instruction = get_task_instruction(
|
||||
frame_instruction_info, priority_order=priority_order
|
||||
)
|
||||
instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order)
|
||||
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
|
||||
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
|
||||
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
|
||||
@@ -540,7 +515,8 @@ def get_wallx_normal_text(
|
||||
complete_text = prologue + user_message + assistant_output
|
||||
return complete_text, generate_subtask
|
||||
|
||||
def img_key_mapping(img_keys: List[str]) -> List[str]:
|
||||
|
||||
def img_key_mapping(img_keys: list[str]) -> list[str]:
|
||||
"""Map image keys to camera names.
|
||||
|
||||
Args:
|
||||
@@ -555,16 +531,15 @@ def img_key_mapping(img_keys: List[str]) -> List[str]:
|
||||
if key in CAMERA_NAME_MAPPING:
|
||||
key = CAMERA_NAME_MAPPING[key]
|
||||
else:
|
||||
if 'view' in key:
|
||||
key = key.replace('_', ' ')
|
||||
if "view" in key:
|
||||
key = key.replace("_", " ")
|
||||
else:
|
||||
key = key + " view"
|
||||
processed_img_keys.append(key)
|
||||
return processed_img_keys
|
||||
|
||||
def get_action_tokens(
|
||||
normalized_actions: Union[torch.Tensor, List], action_tokenizer
|
||||
) -> List[List[str]]:
|
||||
|
||||
def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]:
|
||||
"""Convert normalized actions to action token strings.
|
||||
|
||||
Args:
|
||||
@@ -590,8 +565,8 @@ def get_action_tokens(
|
||||
|
||||
|
||||
def pad_action_token_strs(
|
||||
actions_token_lists: List[List[str]], pad_token: str = "<|endoftext|>"
|
||||
) -> List[str]:
|
||||
actions_token_lists: list[list[str]], pad_token: str = "<|endoftext|>"
|
||||
) -> list[str]:
|
||||
"""Pad action token lists to same length and join as strings.
|
||||
|
||||
Args:
|
||||
@@ -605,20 +580,18 @@ def pad_action_token_strs(
|
||||
padded_action_strs = []
|
||||
|
||||
for tokens in actions_token_lists:
|
||||
padded_tokens = (
|
||||
tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
|
||||
)
|
||||
padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
|
||||
padded_action_strs.append("".join(padded_tokens))
|
||||
|
||||
return padded_action_strs
|
||||
|
||||
|
||||
def replace_action_token(
|
||||
text: List[str],
|
||||
norm_action: Optional[torch.Tensor],
|
||||
text: list[str],
|
||||
norm_action: torch.Tensor | None,
|
||||
action_tokenizer,
|
||||
dof_masks: Optional[torch.Tensor] = None,
|
||||
) -> List[str]:
|
||||
dof_masks: torch.Tensor | None = None,
|
||||
) -> list[str]:
|
||||
"""Replace action placeholders in text with actual action tokens.
|
||||
|
||||
Args:
|
||||
@@ -632,10 +605,7 @@ def replace_action_token(
|
||||
"""
|
||||
if action_tokenizer is not None and norm_action is not None:
|
||||
# Extract actions based on chunk sizes and DOF masks
|
||||
norm_action = [
|
||||
action[: 32, dof_masks[i, 0].bool()]
|
||||
for i, action in enumerate(norm_action)
|
||||
]
|
||||
norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)]
|
||||
|
||||
# Convert to action tokens and pad
|
||||
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
|
||||
@@ -658,4 +628,3 @@ def replace_action_token(
|
||||
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@@ -35,10 +35,11 @@ from lerobot.policies.wall_x import ( # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
|
||||
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
set_seed(42)
|
||||
config = WallXConfig(device='cuda')
|
||||
config = WallXConfig(device="cuda")
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -118,6 +119,7 @@ def test_policy_instantiation():
|
||||
print(f"Action prediction failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
try:
|
||||
@@ -130,6 +132,7 @@ def test_config_creation():
|
||||
print(f"Config creation failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_policy_instantiation()
|
||||
test_config_creation()
|
||||
test_config_creation()
|
||||
|
||||
Reference in New Issue
Block a user