reduce to least config and params & pass lerobot basic test

This commit is contained in:
Geoffrey19
2025-12-04 15:38:04 +08:00
parent 78995621fa
commit b4a7586b27
7 changed files with 349 additions and 832 deletions
@@ -13,13 +13,11 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("wall_x")
@@ -34,48 +32,15 @@ class WallXConfig(PreTrainedConfig):
This config supports multi-modal learning with vision, language, and action data.
"""
# ==================== Model and Paths Configuration ====================
# Logging
log_name: str = "wall_x_training"
log_project: str = "vla_training"
model_type: str = "wall-oss"
# Pretrained model paths
pretrained_wallx_path: str | None = None # Path to pretrained Wall-X model
save_path: str | None = None # Path to save checkpoints
processor_path: str | None = None # Path to processor (defaults to pretrained_wallx_path)
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
# Tokenizer settings
use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
# ==================== Profiling Configuration ====================
profile: bool = False
profile_save_path: str | None = None
profile_wait_iters: int = 10
profile_warmup_iters: int = 5
profile_active_iters: int = 2
# ==================== Training Hyperparameters ====================
num_warmup_steps: int = 100
num_training_steps: int = 64000000
learning_rate: float = 5e-5
min_lr: float = 5e-5
num_epoch: int = 100
gradient_accumulation_steps: int = 32
batch_size_per_gpu: int = 8
padding_side: str = "left"
epoch_save_interval: int = 10
# Training optimization
fsdp2: bool = False
torch_compile: bool = False
# ==================== Input / Output Structure ====================
n_obs_steps: int = 1
chunk_size: int = 32 # action_horizon in wall-x
n_action_steps: int = 32
# Action dimension - wall-x uses 20
max_action_dim: int = 20
max_state_dim: int = 20 # For proprioception
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
@@ -84,101 +49,17 @@ class WallXConfig(PreTrainedConfig):
}
)
# Action dimension - wall-x uses hardcoded 20
max_action_dim: int = 20
max_state_dim: int = 20 # For proprioception
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] | None = None # wall-x uses Qwen processor
# Tokenizer
tokenizer_max_length: int = 256
# ==================== Model Architecture ====================
vlm_model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
load_vlm_weights: bool = True
# Vision config
vision_config: dict = field(default_factory=lambda: {
"depth": 32,
"hidden_size": 3584,
"hidden_act": "silu",
"intermediate_size": 3420,
"num_heads": 16,
"patch_size": 14,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
"window_size": 112,
"out_hidden_size": 3584,
})
# Language model config
hidden_size: int = 3584 # 8192 for 7B model
intermediate_size: int = 18944 # 29568 for 7B model
num_hidden_layers: int = 36 # 80 for 7B model
num_attention_heads: int = 28 # 64 for 7B model
num_key_value_heads: int = 4 # 8 for 7B model
vocab_size: int = 152064
# ==================== Action Prediction ====================
# Action prediction mode: "flow" or "fast"
prediction_mode: str = "flow"
# Pretrained model paths
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
# Flow matching parameters
noise_scheduler: dict = field(default_factory=lambda: {
"beta_alpha": 1.5, # Beta distribution concentration1
"beta_beta": 1.0, # Beta distribution concentration0
"s": 0.999, # Scaling factor for time
})
# Action prediction mode: "diffusion" or "fast"
prediction_mode: str = "diffusion"
# Decoding parameters
num_inference_timesteps: int = 10 # Number of ODE solver steps
ode_solver_method: str = "euler" # ODE solver method
# Tokenizer settings
use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
# ==================== Robot Configuration ====================
# Degrees of freedom configuration - defines action space
dof_config: dict = field(default_factory=lambda: {
"left_ee_pos": 3,
"left_ee_rot": 3,
"left_gripper": 1,
"right_ee_pos": 3,
"right_ee_rot": 3,
"right_gripper": 1,
})
# Proprioception configuration (typically mirrors dof_config)
agent_pos_config: dict = field(default_factory=lambda: {
"left_ee_pos": 3,
"left_ee_rot": 3,
"left_gripper": 1,
"right_ee_pos": 3,
"right_ee_rot": 3,
"right_gripper": 1,
})
# Customized robot configuration
enable_customized_robot_config: bool = False
customized_robot_config: dict = field(default_factory=lambda: {
"name": "",
"customized_dof_config": {},
"customized_agent_pos_config": {},
})
# Normalization statistics path
norm_stats_path: str | None = None
# ==================== MoE Configuration ====================
num_experts: int = 4
attention_moe: bool = False
mlp_moe: bool = False
# ==================== Finetuning Settings ====================
freeze_vision_encoder: bool = True
train_expert_only: bool = False # wall-x trains more components
train_action_head: bool = True
# Cache
use_cache: bool = True
# ==================== Optimizer Presets ====================
optimizer_lr: float = 2e-5
@@ -191,44 +72,6 @@ class WallXConfig(PreTrainedConfig):
scheduler_decay_steps: int = 100000
scheduler_decay_lr: float = 1e-6
# ==================== Dataset Configuration ====================
# Dataset-specific normalization statistics
action_statistics: dict = field(default_factory=dict)
# Data configuration
data_config: dict = field(default_factory=lambda: {
"use_lerobot": True,
"lerobot_config": {
"repo_id": "",
"root": None,
"episodes": None,
"image_transforms": None,
"delta_timestamps": None,
"tolerance_s": 1e-4,
"revision": None,
"force_cache_sync": False,
"download_videos": True,
"video_backend": None,
},
"action_horizon": 32,
"train_test_split": 0.95,
"obs_action_keys": [],
"predict_action_keys": [],
"resolution": {
"face_view": 256,
"left_wrist_view": 256,
"right_wrist_view": 256,
"move1_view": 256,
"move2_view": 256,
"top_view": 256,
"wall_view": 256,
"multi_modal": 256,
},
})
# ==================== Resume Configuration ====================
resume_config: dict | None = field(default_factory=lambda: None)
def __post_init__(self):
super().__post_init__()
@@ -239,243 +82,55 @@ class WallXConfig(PreTrainedConfig):
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.prediction_mode not in ["flow", "fast"]:
if self.prediction_mode not in ["diffusion", "fast"]:
raise ValueError(
f"prediction_mode must be 'flow' or 'fast', got {self.prediction_mode}"
)
# Validate dof_config total doesn't exceed max_action_dim
total_dof = sum(self.dof_config.values())
if total_dof > self.max_action_dim:
raise ValueError(
f"Total DOF ({total_dof}) exceeds max_action_dim ({self.max_action_dim})"
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
# Sync prediction_mode with use_fast_tokenizer
if self.use_fast_tokenizer:
self.prediction_mode = "fast"
else:
self.prediction_mode = "flow"
self.prediction_mode = "diffusion"
def get_train_config(self) -> dict:
"""
Extract the complete train_config dictionary matching the YAML training configuration format.
def validate_features(self) -> None:
"""Validate and set up input/output features."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"Wall-X policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
This method constructs the full train_config from WallXConfig fields, suitable for
training scripts and Qwen2_5_VLMoEForAction.from_pretrained.
Returns:
dict: Complete training configuration matching YAML structure.
"""
# Build customized_robot_config
if self.enable_customized_robot_config and self.customized_robot_config:
customized_robot_config = {
"name": self.customized_robot_config.get("name", ""),
"customized_dof_config": self.customized_robot_config.get(
"customized_dof_config", self.dof_config
),
"customized_agent_pos_config": self.customized_robot_config.get(
"customized_agent_pos_config", self.agent_pos_config
),
}
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
else:
customized_robot_config = {
"name": self.data_config.get("lerobot_config", {}).get("repo_id", ""),
"customized_dof_config": self.dof_config,
"customized_agent_pos_config": self.agent_pos_config,
}
state_shape = self.input_features["observation.state"].shape
state_dim = state_shape[0] if state_shape else 0
if state_dim > self.max_state_dim:
raise ValueError(
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
f"Either reduce state dimension or increase max_state_dim in config."
)
train_config = {
# Model and paths configuration
"log_name": self.log_name,
"log_project": self.log_project,
"model_type": self.model_type,
"pretrained_wallx_path": self.pretrained_wallx_path,
"save_path": self.save_path,
"use_fast_tokenizer": self.use_fast_tokenizer,
"action_tokenizer_path": self.action_tokenizer_path,
# Profiling configuration
"profile": self.profile,
"profile_save_path": self.profile_save_path,
"profile_wait_iters": self.profile_wait_iters,
"profile_warmup_iters": self.profile_warmup_iters,
"profile_active_iters": self.profile_active_iters,
# Training hyperparameters
"num_warmup_steps": self.num_warmup_steps,
"num_training_steps": self.num_training_steps,
"learning_rate": self.learning_rate,
"min_lr": self.min_lr,
"num_epoch": self.num_epoch,
"gradient_accumulation_steps": self.gradient_accumulation_steps,
"batch_size_per_gpu": self.batch_size_per_gpu,
"padding_side": self.padding_side,
"epoch_save_interval": self.epoch_save_interval,
# Training optimization
"FSDP2": self.fsdp2,
"torch_compile": self.torch_compile,
# Robot configuration
"dof_config": self.dof_config,
"agent_pos_config": self.agent_pos_config,
# Normalization stats
"norm_stats_path": self.norm_stats_path,
# Customized robot config
"enable_customized_robot_config": self.enable_customized_robot_config,
"customized_robot_config": customized_robot_config,
# Resume configuration
"resume": self.resume_config,
# Data configuration
"data": self.data_config,
}
return train_config
def get_dataload_config(self) -> dict:
"""
Extract data loading configuration from config.
Returns:
dict: Data loading configuration for preprocessing.
"""
return {
"action_horizon": self.data_config.get("action_horizon", self.chunk_size),
"train_test_split": self.data_config.get("train_test_split", 0.95),
"split_seed": 42,
"predict_action_keys": self.data_config.get("predict_action_keys", []),
"obs_action_keys": self.data_config.get("obs_action_keys", []),
"resolution": self.data_config.get("resolution", {}),
"priority_order": None,
"max_length": self.tokenizer_max_length,
}
def get_lerobot_config(self) -> dict:
"""
Extract LeRobot dataset configuration.
Returns:
dict: LeRobot dataset configuration.
"""
return self.data_config.get("lerobot_config", {})
@classmethod
def from_yaml_dict(cls, yaml_dict: dict) -> "WallXConfig":
"""
Create a WallXConfig from a YAML configuration dictionary.
Args:
yaml_dict: Dictionary loaded from YAML training config file.
Returns:
WallXConfig instance with values from YAML.
"""
config_kwargs = {}
# Model and paths
if "log_name" in yaml_dict:
config_kwargs["log_name"] = yaml_dict["log_name"]
if "log_project" in yaml_dict:
config_kwargs["log_project"] = yaml_dict["log_project"]
if "model_type" in yaml_dict:
config_kwargs["model_type"] = yaml_dict["model_type"]
if "pretrained_wallx_path" in yaml_dict:
config_kwargs["pretrained_wallx_path"] = yaml_dict["pretrained_wallx_path"]
if "save_path" in yaml_dict:
config_kwargs["save_path"] = yaml_dict["save_path"]
if "use_fast_tokenizer" in yaml_dict:
config_kwargs["use_fast_tokenizer"] = yaml_dict["use_fast_tokenizer"]
if "action_tokenizer_path" in yaml_dict:
config_kwargs["action_tokenizer_path"] = yaml_dict["action_tokenizer_path"]
# Profiling
if "profile" in yaml_dict:
config_kwargs["profile"] = yaml_dict["profile"]
if "profile_save_path" in yaml_dict:
config_kwargs["profile_save_path"] = yaml_dict["profile_save_path"]
if "profile_wait_iters" in yaml_dict:
config_kwargs["profile_wait_iters"] = yaml_dict["profile_wait_iters"]
if "profile_warmup_iters" in yaml_dict:
config_kwargs["profile_warmup_iters"] = yaml_dict["profile_warmup_iters"]
if "profile_active_iters" in yaml_dict:
config_kwargs["profile_active_iters"] = yaml_dict["profile_active_iters"]
# Training hyperparameters
if "num_warmup_steps" in yaml_dict:
config_kwargs["num_warmup_steps"] = yaml_dict["num_warmup_steps"]
config_kwargs["scheduler_warmup_steps"] = yaml_dict["num_warmup_steps"]
if "num_training_steps" in yaml_dict:
config_kwargs["num_training_steps"] = yaml_dict["num_training_steps"]
config_kwargs["scheduler_decay_steps"] = yaml_dict["num_training_steps"]
if "learning_rate" in yaml_dict:
config_kwargs["learning_rate"] = yaml_dict["learning_rate"]
config_kwargs["optimizer_lr"] = yaml_dict["learning_rate"]
if "min_lr" in yaml_dict:
config_kwargs["min_lr"] = yaml_dict["min_lr"]
config_kwargs["scheduler_decay_lr"] = yaml_dict["min_lr"]
if "num_epoch" in yaml_dict:
config_kwargs["num_epoch"] = yaml_dict["num_epoch"]
if "gradient_accumulation_steps" in yaml_dict:
config_kwargs["gradient_accumulation_steps"] = yaml_dict["gradient_accumulation_steps"]
if "batch_size_per_gpu" in yaml_dict:
config_kwargs["batch_size_per_gpu"] = yaml_dict["batch_size_per_gpu"]
if "padding_side" in yaml_dict:
config_kwargs["padding_side"] = yaml_dict["padding_side"]
if "epoch_save_interval" in yaml_dict:
config_kwargs["epoch_save_interval"] = yaml_dict["epoch_save_interval"]
# Training optimization
if "FSDP2" in yaml_dict:
config_kwargs["fsdp2"] = yaml_dict["FSDP2"]
if "torch_compile" in yaml_dict:
config_kwargs["torch_compile"] = yaml_dict["torch_compile"]
# Robot configuration
if "dof_config" in yaml_dict:
config_kwargs["dof_config"] = yaml_dict["dof_config"]
if "agent_pos_config" in yaml_dict:
config_kwargs["agent_pos_config"] = yaml_dict["agent_pos_config"]
# Normalization stats
if "norm_stats_path" in yaml_dict:
config_kwargs["norm_stats_path"] = yaml_dict["norm_stats_path"]
# Customized robot config
if "enable_customized_robot_config" in yaml_dict:
config_kwargs["enable_customized_robot_config"] = yaml_dict["enable_customized_robot_config"]
if "customized_robot_config" in yaml_dict:
config_kwargs["customized_robot_config"] = yaml_dict["customized_robot_config"]
# Resume config
if "resume" in yaml_dict:
config_kwargs["resume_config"] = yaml_dict["resume"]
# Data configuration
if "data" in yaml_dict:
data = yaml_dict["data"]
data_config = {
"use_lerobot": data.get("use_lerobot", True),
"action_horizon": data.get("action_horizon", 32),
"train_test_split": data.get("train_test_split", 0.95),
"obs_action_keys": data.get("obs_action_keys", []),
"predict_action_keys": data.get("predict_action_keys", []),
"resolution": data.get("resolution", {}),
}
if "lerobot_config" in data:
data_config["lerobot_config"] = data["lerobot_config"]
config_kwargs["data_config"] = data_config
# Set chunk_size from action_horizon
if "action_horizon" in data:
config_kwargs["chunk_size"] = data["action_horizon"]
config_kwargs["n_action_steps"] = data["action_horizon"]
return cls(**config_kwargs)
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
else:
action_shape = self.output_features["action"].shape
action_dim = action_shape[0] if action_shape else 0
if action_dim > self.max_action_dim:
raise ValueError(
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
f"Either reduce action dimension or increase max_action_dim in config."
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
@@ -496,7 +151,7 @@ class WallXConfig(PreTrainedConfig):
@property
def observation_delta_indices(self) -> list:
return [0]
return None
@property
def action_delta_indices(self) -> list:
+13 -7
View File
@@ -16,15 +16,9 @@
"""
Wall-X Constants and Configuration Data.
Contains dataset names, key mappings, frequency mappings, and action statistics
for cross-embodiment robotic control.
"""
from pathlib import Path
# Add wall-x repo to path if available
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
from lerobot.utils.constants import OBS_STATE, OBS_IMAGES, ACTION
CAMERA_NAME_MAPPING = {
"face_view": "front view",
@@ -35,3 +29,15 @@ CAMERA_NAME_MAPPING = {
"wall_view": "wall view",
"top_view": "top view",
}
RESOLUTION = 256
# Parameters for preprocessing
MAX_PIXELS = 16384 * 28 * 28
MIN_PIXELS = 4 * 28 * 28
IMAGE_FACTOR = 28
PRIORITY_ORDER = None
GENERATE_SUBTASK_RATIO = 0.0
MODEL_TYPE = "qwen2_5"
TOKENIZER_MAX_LENGTH = 768
+118 -355
View File
@@ -38,6 +38,7 @@ import builtins
import glob
import math
import os
from os import PathLike
import sys
from collections import deque
from pathlib import Path
@@ -57,16 +58,8 @@ from torchdiffeq import odeint
from transformers import AutoConfig, AutoProcessor
from transformers.activations import ACT2FN
from transformers.cache_utils import (
Cache,
DynamicCache,
SlidingWindowCache,
StaticCache,
)
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_torchdynamo_compiling, logging
from transformers import AutoProcessor, BatchFeature
from qwen_vl_utils.vision_process import smart_resize
@@ -80,29 +73,18 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
from lerobot.policies.wall_x.utils import *
from lerobot.policies.wall_x.constant import *
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2RMSNorm,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLRotaryEmbedding,
Qwen2_5_VLPreTrainedModel,
Qwen2_5_VLForConditionalGeneration,
)
from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLDecoderLayer_with_MoE,
Qwen2_5_VLACausalLMOutputWithPast,
Qwen2_5_VLMoEModel,
)
logger = logging.get_logger(__name__)
# Add wall-x repo to path if available
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
if WALL_X_PATH.exists():
sys.path.insert(0, str(WALL_X_PATH))
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embedding for diffusion timesteps."""
@@ -129,7 +111,7 @@ class ActionHead(nn.Module):
for action sequence prediction.
"""
def __init__(self, config: WallXConfig):
def __init__(self, config):
super().__init__()
self.config = config
@@ -138,10 +120,9 @@ class ActionHead(nn.Module):
self.hidden_size = config.hidden_size
# Beta distribution for noise scheduling
noise_config = config.noise_scheduler
self.beta_alpha = noise_config.get("beta_alpha", 1.5)
self.beta_beta = noise_config.get("beta_beta", 1.0)
self.s = noise_config.get("s", 0.999)
self.beta_alpha = 1.5
self.beta_beta = 1.0
self.s = 0.999
# Sinusoidal timestep embedding
self.time_embed = SinusoidalPosEmb(config.hidden_size)
@@ -277,12 +258,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
@classmethod
def from_pretrained(
cls,
pretrained_model_path,
train_config,
config_path=None,
processor_path=None,
pretrained_name_or_path,
config=None,
action_tokenizer_path=None,
**kwargs,
cache_dir: str | PathLike | None = None,
force_download: bool = False,
local_files_only: bool = False,
token: str | bool | None = None,
revision: str = "main",
strict: bool = False,
**kwargs: Any
):
"""
Load model from pretrained model path.
@@ -290,17 +275,24 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
Args:
pretrained_model_path (str): Model directory path containing model.safetensors file
config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
processor_path (str, optional): Processor path, if None will load from default config
action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
**kwargs: Additional arguments
Returns:
Qwen2_5_VLMoEForAction: Loaded model instance
"""
# Load model components from pretrained path
config_path = os.path.join(pretrained_model_path, "config.json")
config = cls.config_class.from_pretrained(config_path)
processor = AutoProcessor.from_pretrained(pretrained_model_path, use_fast=True)
if config is None:
config = cls.config_class.from_pretrained(
pretrained_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
strict=strict,
**kwargs,
)
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
if action_tokenizer_path is not None:
processor.action_processor = AutoProcessor.from_pretrained(
action_tokenizer_path, trust_remote_code=True
@@ -312,18 +304,37 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Resize token embeddings to match processor tokenizer vocabulary size
model.resize_token_embeddings(len(processor.tokenizer))
# Load model state dict from safetensors file
safetensor_files = glob.glob(
os.path.join(pretrained_model_path, "*.safetensors")
# Try to load the model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
from safetensors.torch import load_file
sd = load_file(resolved_file)
print("✓ Loaded state dict from model.safetensors")
except Exception as e:
print(f"Could not load state dict from remote files: {e}")
print("Returning model without loading pretrained weights")
return model
state_dict = {}
for file in safetensor_files:
sd = load_file(file, device="cpu")
# filter normalizer statistic params
del_keys = []
for key in sd.keys():
if "action_preprocessor.normalizer" in key:
print(f"filter load model weight {key}")
del_keys.append(key)
for key in del_keys:
del sd[key]
@@ -730,7 +741,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
dataset_names: Optional[str] = None,
dof_mask: Optional[torch.FloatTensor] = None,
agent_pos_mask: Optional[torch.FloatTensor] = None,
**kwargs,
@@ -763,7 +773,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
rope_deltas (torch.LongTensor, optional): RoPE position deltas
cache_position (torch.LongTensor, optional): Cache position indices
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
dataset_names (str, optional): Names of datasets in the current batch
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens
agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data
**kwargs: Additional keyword arguments
@@ -872,7 +881,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
)
proprioception = self.action_preprocessor.proprioception_proj(
proprioception,
dataset_names,
agent_pos_mask,
use_history=proprioception.shape[1] > 1,
)
@@ -909,7 +917,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
)
dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
noisy_action_emb, flow = self.action_preprocessor(
action_chunk, dataset_names, dof_mask
action_chunk, dof_mask
)
mask = input_ids == self.action_token_id_set["action_token_id"]
mask_unsqueezed = mask.unsqueeze(-1)
@@ -953,18 +961,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Compute losses if labels are provided
if labels is not None:
loss = 0
action_accuracy = 0
unique_datasets_name = list(set(dataset_names))
# Initialize per-dataset loss tracking dictionaries
channel_loss_dict = {
dataset_name: torch.tensor(0.0, device=logits.device)
for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES
}
channel_loss_count_dict = {
dataset_name: torch.tensor(0, device=logits.device)
for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES
}
# Compute standard cross-entropy loss for language modeling
shift_logits = logits[..., :-1, :].contiguous()
@@ -982,22 +978,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
else torch.tensor(0.0, device=shift_logits.device)
)
# Compute per-dataset channel losses
_cross_entropy_loss = _cross_entropy_loss.view(batch_size, seq_length - 1)
non_ignored_mask = non_ignored_mask.view(batch_size, seq_length - 1)
for dataset_name_i in unique_datasets_name:
dataset_mask = torch.tensor(
[name == dataset_name_i for name in dataset_names],
device=logits.device,
)
combined_mask = dataset_mask.unsqueeze(1) & non_ignored_mask
channel_loss_dict[dataset_name_i] = (
_cross_entropy_loss[combined_mask].sum()
if combined_mask.any()
else torch.tensor(0.0, device=shift_logits.device)
)
channel_loss_count_dict[dataset_name_i] += combined_mask.sum()
# Add cross-entropy loss to total loss if valid
if not torch.isnan(cross_entropy_loss):
loss += cross_entropy_loss
@@ -1005,20 +985,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
with torch.no_grad():
cross_entropy_loss.detach()
# Compute action token prediction accuracy
shift_logits = logits[..., :-1, :].contiguous()
action_preds = shift_logits.argmax(dim=-1)
shift_labels = labels[..., 1:].contiguous()
if self.use_fast_tokenizer:
action_mask = (
shift_labels > self.action_token_id_set["fast_action_token_list"][0]
)
correct_preds = (action_preds == shift_labels) & action_mask
action_accuracy = (
correct_preds.sum().float() / action_mask.sum().float()
)
channel_loss_dict["action_accuracy"] = action_accuracy
if action_chunk is not None:
action_mask = input_ids == self.action_token_id_set["action_token_id"]
if action_mask.any():
@@ -1101,7 +1067,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
num_inference_timesteps: Optional[int] = 10,
dataset_names: Optional[str] = None,
dof_mask: Optional[torch.FloatTensor] = None,
agent_pos_mask: Optional[torch.FloatTensor] = None,
re_generate: bool = False,
@@ -1140,7 +1105,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
cache_position (torch.LongTensor, optional): Cache position indices
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
num_inference_timesteps (int, optional): Number of diffusion inference steps
dataset_names (str, optional): Dataset names for normalization
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask
agent_pos_mask (torch.FloatTensor, optional): Agent position mask
re_generate (bool, optional): Whether to use sampling for regeneration
@@ -1239,7 +1203,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
)
proprio_embed = self.action_preprocessor.proprioception_proj(
proprioception,
dataset_names,
agent_pos_mask,
use_history=proprioception.shape[1] > 1,
)
@@ -1339,7 +1302,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
"dof_mask": dof_mask,
"agent_pos_mask": agent_pos_mask,
"proprioception": proprioception,
"dataset_names": dataset_names,
}
# Generate output tokens
@@ -1545,7 +1507,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
image_grid_thw=None,
video_grid_thw=None,
second_per_grid_ts=None,
dataset_names=None,
proprioception=None,
dof_mask=None,
agent_pos_mask=None,
@@ -1572,7 +1533,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
image_grid_thw: Image grid dimensions
video_grid_thw: Video grid dimensions
second_per_grid_ts: Time interval per temporal grid
dataset_names: Dataset names for processing
proprioception: Proprioceptive sensor data
dof_mask: Degrees of freedom mask
agent_pos_mask: Agent position mask
@@ -1677,7 +1637,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
"cache_position": cache_position,
"second_per_grid_ts": second_per_grid_ts,
"proprioception": proprioception,
"dataset_names": dataset_names,
"dof_mask": dof_mask,
"agent_pos_mask": agent_pos_mask,
}
@@ -1866,150 +1825,14 @@ class WallXPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
# Initialize VLM wrapper
self.model = Qwen2_5_VLMoEForAction(config)
# Initialize the wall-x model
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
self.model.to(config.device)
# Convert to bfloat16 for Flash Attention compatibility
self.model.to_bfloat16_for_selected_params()
self.reset()
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = False,
**kwargs,
) -> T:
"""
Load WallXPolicy from a pretrained model path.
Args:
pretrained_name_or_path: Path to pretrained model or model identifier
config: Optional configuration object
force_download: Force download even if cached
resume_download: Resume interrupted download
proxies: Proxy configuration
token: Authentication token
cache_dir: Cache directory path
local_files_only: Only use local files
revision: Model revision
strict: Strict loading of state dict
**kwargs: Additional arguments
Returns:
WallXPolicy: Loaded policy instance
"""
print(
"Loading Wall-X model for cross-embodiment robotic control.\n"
"This implementation integrates Qwen2.5-VL with flow matching for action prediction."
)
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
# Use provided config if available, otherwise load from pretrained path
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
# Initialize model without loading weights
model = cls(config, **kwargs)
# Load and remap the state dict
try:
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
)
original_state_dict = load_file(resolved_file)
print("✓ Loaded state dict from model.safetensors")
except Exception:
print(f"Could not load state dict: {e}")
print("Returning model without loading pretrained weights")
return model
# Filter out normalizer statistics if present
filtered_state_dict = {}
for key, value in original_state_dict.items():
if "action_preprocessor.normalizer" not in key:
filtered_state_dict[key] = value
else:
print(f"Filtered key: {key}")
# Add "model." prefix for keys that don't have it
remapped_state_dict = {}
remap_count = 0
for key, value in filtered_state_dict.items():
if not key.startswith("model."):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
else:
remapped_state_dict[key] = value
if remap_count > 0:
print(f"Remapped {remap_count} state dict keys")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
if missing_keys:
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not load state dict: {e}")
return model
def reset(self):
"""Reset action queue."""
self._queues = {
@@ -2022,88 +1845,50 @@ class WallXPolicy(PreTrainedPolicy):
def preprocess_inputs(
self,
batch: List[Dict[str, Any]],
config: Dict[str, Any],
dataload_config: Dict[str, Any],
lerobot_config: Dict[str, Any],
processor: Any,
action_tokenizer: Optional[Any] = None,
camera_keys: Optional[List[str]] = None,
batch: Dict[str, Any],
) -> BatchFeature:
"""
Convert a batch of LeRobot dataset items to Wall-X model input format.
This is the batch version of convert_lerobot_to_wallx_format, processing multiple
samples together for efficient batched inference or training.
This processes a batched dictionary where tensors have batch dimension first.
Args:
batch: List of items from LeRobot dataset
config: Model configuration dict
dataload_config: Data loading configuration
norm_stats: Normalization statistics
lerobot_config: LeRobot config containing 'repo_id'
processor: Hugging Face processor for tokenization
action_tokenizer: Optional action tokenizer
camera_keys: List of camera keys in the dataset
batch: Dictionary with batched tensors:
- "observation.state": (batch_size, state_dim) or (batch_size, n_obs_steps, state_dim)
- "action": (batch_size, chunk_size, action_dim)
- "observation.images.<key>": (batch_size, C, H, W)
- "task": List[str] of length batch_size
Returns:
BatchFeature containing batched model inputs
Example:
>>> batch = [dataset[i] for i in range(8)]
>>> result = preprocess_inputs(
... batch=batch,
... config=config,
... dataload_config=dataload_config,
... norm_stats=norm_stats,
... lerobot_config={'repo_id': 'lerobot/aloha_mobile_cabinet'},
... processor=processor,
... )
"""
repo_id = lerobot_config["repo_id"]
use_fast_tokenizer = config.get("use_fast_tokenizer", False)
use_fast_tokenizer = self.config.use_fast_tokenizer
# Get key mappings
cam_key_mapping = KEY_MAPPINGS[repo_id]["camera"]
state_key = KEY_MAPPINGS[repo_id]["state"]
action_key = KEY_MAPPINGS[repo_id]["action"]
# Build data config
data_config = X2RDataProcessingConfig().update(
train_test_split=dataload_config.get("train_test_split", 0.95),
split_seed=dataload_config.get("split_seed", 42),
predict_action_keys=dataload_config.get("predict_action_keys", []),
obs_action_keys=dataload_config.get("obs_action_keys", []),
resolution=dataload_config.get("resolution", None),
priority_order=dataload_config.get("priority_order", None),
)
if camera_keys is None:
camera_keys = list(cam_key_mapping.keys())
# Get batch size from state tensor
batch_size = batch[OBS_STATE].shape[0]
# ==================== PROCESS ALL SAMPLES ====================
all_image_inputs = []
all_texts = []
all_agent_pos = []
all_actions = []
all_frame_indices = []
for data in batch:
# Find image keys in batch
img_keys = [key for key in self.config.image_features if key in batch]
for i in range(batch_size):
# Vision preprocessing per sample
processed_frames = []
orig_height, orig_width = None, None
resized_height, resized_width = None, None
for key in camera_keys:
current_obs = data[key].clone()
for key in img_keys:
current_obs = batch[key][i].clone() # (C, H, W)
if current_obs.dim() == 3:
current_obs = current_obs.permute(1, 2, 0)
current_obs = current_obs.permute(1, 2, 0) # (H, W, C)
img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy())
orig_width, orig_height = img_pil.size
cam_name = cam_key_mapping.get(key, key)
target_size = data_config.resolution.get(cam_name, -1)
target_size = RESOLUTION
if target_size != -1:
if orig_width > orig_height:
new_width = target_size
@@ -2117,9 +1902,9 @@ class WallXPolicy(PreTrainedPolicy):
resized_height, resized_width = smart_resize(
current_height,
current_width,
factor=data_config.image_factor,
min_pixels=data_config.min_pixels,
max_pixels=data_config.max_pixels,
factor=IMAGE_FACTOR,
min_pixels=MIN_PIXELS,
max_pixels=MAX_PIXELS,
)
resized_img = img_pil.resize((resized_width, resized_height))
processed_frames.append(resized_img)
@@ -2127,33 +1912,30 @@ class WallXPolicy(PreTrainedPolicy):
all_image_inputs.append(processed_frames)
# Text preprocessing
frame_index = data["frame_index"]
instruction_info = {"instruction": data["task"]}
task_text = batch["task"][i] if isinstance(batch["task"], list) else batch["task"]
instruction_info = {"instruction": task_text}
frame_index = batch["frame_index"][i] if "frame_index" in batch else 0
complete_text, _ = get_wallx_normal_text(
instruction_info,
dataload_config.get("action_horizon", 33) - 1,
self.config.chunk_size,
frame_index,
data_config.priority_order,
cam_key_mapping,
generate_subtask_ratio=data_config.generate_subtask_ratio,
PRIORITY_ORDER,
img_keys,
generate_subtask_ratio=GENERATE_SUBTASK_RATIO,
)
text = process_grounding_points(
complete_text, orig_height, orig_width, resized_height, resized_width,
data_config.model_type
MODEL_TYPE
)
all_texts.append(text)
# Collect raw values
all_agent_pos.append(data[state_key])
all_actions.append(data[action_key])
all_frame_indices.append(frame_index)
# Stack agent_pos
agent_pos = torch.stack(all_agent_pos)
# ==================== PROCESS AGENT POS ====================
agent_pos = batch[OBS_STATE] # (batch_size, state_dim)
if agent_pos.dim() == 2:
agent_pos = agent_pos.unsqueeze(1)
agent_pos = agent_pos.unsqueeze(1) # (batch_size, 1, state_dim)
agent_pos_mask = (~torch.isnan(agent_pos)).float()
agent_pos = agent_pos.nan_to_num(nan=0.0)
@@ -2161,15 +1943,15 @@ class WallXPolicy(PreTrainedPolicy):
pad_size = 20 - agent_pos.shape[-1]
agent_pos = torch.cat([
agent_pos,
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size)
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device)
], dim=-1)
agent_pos_mask = torch.cat([
agent_pos_mask,
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size)
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size, device=agent_pos_mask.device)
], dim=-1)
# Stack actions
action = torch.stack(all_actions)
# ==================== PROCESS ACTIONS ====================
action = batch[ACTION] # (batch_size, chunk_size, action_dim)
if action.dim() == 2:
action = action.unsqueeze(1)
dof_mask = (~torch.isnan(action)).float()
@@ -2179,36 +1961,35 @@ class WallXPolicy(PreTrainedPolicy):
pad_size = 20 - action.shape[-1]
action = torch.cat([
action,
torch.zeros(action.shape[0], action.shape[1], pad_size)
torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
], dim=-1)
dof_mask = torch.cat([
dof_mask,
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size)
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
], dim=-1)
# ==================== ACTION TOKEN REPLACEMENT ====================
all_texts = replace_action_token(
all_texts,
action,
action_tokenizer if use_fast_tokenizer else None,
[repo_id] * len(batch),
self.model.action_tokenizer if use_fast_tokenizer else None,
dof_mask,
)
# ==================== TOKENIZATION ====================
inputs = preprocesser_call(
processor=processor,
processor=self.model.processor,
text=all_texts,
images=all_image_inputs,
videos=None,
padding=True,
truncation=True,
return_tensors="pt",
max_length=dataload_config.get("max_length", 768),
max_length=TOKENIZER_MAX_LENGTH,
)
# ==================== ADDITIONAL INPUTS ====================
action_token_id = processor.tokenizer.convert_tokens_to_ids("<|action|>")
action_token_id = self.model.processor.tokenizer.convert_tokens_to_ids("<|action|>")
moe_token_types = inputs.input_ids == action_token_id
inputs["proprioception"] = agent_pos
@@ -2216,11 +1997,13 @@ class WallXPolicy(PreTrainedPolicy):
inputs["action_chunk"] = action
inputs["dof_mask"] = dof_mask
inputs["moe_token_types"] = moe_token_types
inputs["dataset_names"] = [repo_id] * len(batch)
inputs["frame_index"] = torch.stack([
torch.tensor(fi) if not isinstance(fi, torch.Tensor) else fi
for fi in all_frame_indices
])
inputs["frame_index"] = batch["frame_index"] if "frame_index" in batch else torch.zeros(batch_size, device=batch[OBS_STATE].device)
# Move all tensors to the correct device
device = self.config.device
for key, value in inputs.items():
if isinstance(value, torch.Tensor):
inputs[key] = value.to(device)
return inputs
@@ -2232,37 +2015,19 @@ class WallXPolicy(PreTrainedPolicy):
batch: Dictionary containing preprocessed inputs from preprocess_inputs()
Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw,
proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types,
dataset_names, etc.
etc.
Returns:
tuple: (loss, loss_dict)
"""
batch = self.preprocess_inputs(
batch,
self.config,
self.dataload_config,
self.lerobot_config,
self.processor,
self.action_tokenizer,
self.camera_keys
)
# Call the underlying model's forward with mode="train"
outputs = self.model(
mode="train",
input_ids=batch.get("input_ids"),
attention_mask=batch.get("attention_mask"),
pixel_values=batch.get("pixel_values"),
image_grid_thw=batch.get("image_grid_thw"),
pixel_values_videos=batch.get("pixel_values_videos"),
video_grid_thw=batch.get("video_grid_thw"),
proprioception=batch.get("proprioception"),
agent_pos_mask=batch.get("agent_pos_mask"),
action_chunk=batch.get("action_chunk"),
dof_mask=batch.get("dof_mask"),
moe_token_types=batch.get("moe_token_types"),
dataset_names=batch.get("dataset_names"),
labels=batch.get("labels", batch.get("input_ids")), # Use input_ids as labels if not provided
**batch,
mode="train"
)
# Extract losses from output
@@ -2292,16 +2057,10 @@ class WallXPolicy(PreTrainedPolicy):
batch = self.preprocess_inputs(
batch,
self.config,
self.dataload_config,
self.lerobot_config,
self.processor,
self.action_tokenizer,
self.camera_keys
)
if self.config.prediction_mode == "diffusion":
actions = self.model(
output = self.model(
**batch,
action_dim=self.config.max_action_dim,
pred_horizon=self.config.chunk_size,
@@ -2309,9 +2068,9 @@ class WallXPolicy(PreTrainedPolicy):
predict_mode="diffusion"
)
elif self.config.prediction_mode == "fast":
actions = self.model(
output = self.model(
**batch,
action_dim=self.config.action_feature.shape[0],
action_dim=self.config.output_features["action"].shape[0],
pred_horizon=self.config.chunk_size,
mode="predict",
predict_mode="fast"
@@ -2319,8 +2078,12 @@ class WallXPolicy(PreTrainedPolicy):
else:
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
# Unpad actions
actions = actions[:, :, :self.config.action_feature.shape[0]]
# Extract action tensor from output dictionary
actions = output["predict_action"]
# Unpad actions to actual action dimension
action_dim = self.config.output_features["action"].shape[0]
actions = actions[:, :, :action_dim]
return actions
@@ -29,13 +29,10 @@ from lerobot.processor import (
PolicyProcessorPipeline,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_wall_x_pre_post_processors(
config: WallXConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
@@ -49,7 +46,6 @@ def make_wall_x_pre_post_processors(
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations
2. Adding a batch dimension
3. Tokenizing language task descriptions
4. Normalizing input and output features based on dataset statistics
5. Moving all data to the specified device
@@ -65,25 +61,10 @@ def make_wall_x_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines
"""
# Try to use Qwen processor if available
try:
from transformers import AutoProcessor
tokenizer_name = config.vlm_model_name
qwen_available = True
except ImportError:
tokenizer_name = "Qwen/Qwen2-VL-2B-Instruct" # Fallback
qwen_available = False
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
WallXTaskProcessor(), # Process task description
TokenizerProcessorStep(
tokenizer_name=tokenizer_name,
padding="max_length",
padding_side="right",
max_length=config.tokenizer_max_length,
),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
@@ -152,30 +133,3 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register(name="wall_x_image_processor")
class WallXImageProcessor(ComplementaryDataProcessorStep):
"""
Image processor for Wall-X using Qwen-VL vision processing.
This handles image formatting according to Qwen-VL requirements.
"""
def __init__(self):
super().__init__()
try:
from transformers import AutoProcessor
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
self.available = True
except ImportError:
self.available = False
def complementary_data(self, complementary_data):
# Image processing is handled by the VLM processor
return complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
+29 -18
View File
@@ -33,10 +33,8 @@ from transformers import BatchFeature
from lerobot.policies.wall_x.constant import (
CAMERA_NAME_MAPPING,
FREQUENCY_MAPPING,
KEY_MAPPINGS,
MULTIMODAL_DATASET_NAMES,
)
from lerobot.utils.constants import OBS_IMAGES
@dataclass
@@ -457,7 +455,7 @@ def get_wallx_normal_text(
action_chunk_size: int,
frame_idx: int,
priority_order: Optional[OrderedDict] = None,
cam_mapping: Optional[Dict[str, str]] = None,
img_keys: Optional[List[str]] = None,
generate_subtask_ratio: float = 0.0,
) -> Tuple[str, bool]:
"""Construct complete multimodal prompt text for Wall-X model.
@@ -474,7 +472,7 @@ def get_wallx_normal_text(
action_chunk_size: Number of action tokens to generate
frame_idx: Current frame index
priority_order: Priority order for instruction sampling
cam_mapping: Camera name mapping dictionary
img_keys: List of image keys
generate_subtask_ratio: Probability of generating subtask instead of actions
Returns:
@@ -497,10 +495,10 @@ def get_wallx_normal_text(
# User request with observation
user_request = f"{role_start_symbol}user\nObservation:"
if cam_mapping:
for _, cam_name in cam_mapping.items():
view_name = CAMERA_NAME_MAPPING.get(cam_name, cam_name)
user_request += f" {view_name}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}"
if img_keys:
img_keys = img_key_mapping(img_keys)
for key in img_keys:
user_request += f" {key}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}"
user_request += "\nInstruction:"
# Get frame-specific instruction
@@ -543,6 +541,27 @@ def get_wallx_normal_text(
complete_text = prologue + user_message + assistant_output
return complete_text, generate_subtask
def img_key_mapping(img_keys: List[str]) -> List[str]:
"""Map image keys to camera names.
Args:
img_keys: List of image keys
Returns:
List of camera names
"""
processed_img_keys = []
for key in img_keys:
key = key.replace(OBS_IMAGES + ".", "")
if key in CAMERA_NAME_MAPPING:
key = CAMERA_NAME_MAPPING[key]
else:
if 'view' in key:
key = key.replace('_', ' ')
else:
key = key + " view"
processed_img_keys.append(key)
return processed_img_keys
def get_action_tokens(
normalized_actions: Union[torch.Tensor, List], action_tokenizer
@@ -599,7 +618,6 @@ def replace_action_token(
text: List[str],
norm_action: Optional[torch.Tensor],
action_tokenizer,
dataset_names: List[str],
dof_masks: Optional[torch.Tensor] = None,
) -> List[str]:
"""Replace action placeholders in text with actual action tokens.
@@ -615,17 +633,10 @@ def replace_action_token(
List of text strings with action tokens replaced
"""
# Filter out multimodal dataset names
dataset_names = [
name for name in dataset_names if name not in MULTIMODAL_DATASET_NAMES
]
# Get required action chunk sizes
required_chunk_sizes = [32 for name in dataset_names]
if action_tokenizer is not None and norm_action is not None:
# Extract actions based on chunk sizes and DOF masks
norm_action = [
action[: required_chunk_sizes[i], dof_masks[i, 0].bool()]
action[: 32, dof_masks[i, 0].bool()]
for i, action in enumerate(norm_action)
]
+2
View File
@@ -0,0 +1,2 @@
# Wall-X policy tests
+126
View File
@@ -0,0 +1,126 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local Wall-X installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.wall_x import ( # noqa: E402
WallXConfig,
WallXPolicy,
make_wall_x_pre_post_processors, # noqa: E402
)
from lerobot.utils.random_utils import set_seed # noqa: E402
def test_policy_instantiation():
# Create config
set_seed(42)
config = WallXConfig(device='cuda')
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
config.input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
),
"observation.images.face_view": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
}
config.output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(7,),
),
}
# Create dummy dataset stats
dataset_stats = {
"observation.state": {
"mean": torch.zeros(7),
"std": torch.ones(7),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
},
"observation.images.face_view": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
}
# Instantiate policy
policy = WallXPolicy(config)
preprocessor, postprocessor = make_wall_x_pre_post_processors(config=config, dataset_stats=dataset_stats)
# Test forward pass with dummy data
batch_size = 1
device = config.device
batch = {
"observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device),
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
"observation.images.face_view": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
), # Use rand for [0,1] range
"task": ["Pick up the object"] * batch_size,
}
batch = preprocessor(batch)
try:
loss, loss_dict = policy.forward(batch)
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
except Exception as e:
print(f"Forward pass failed: {e}")
raise
try:
with torch.no_grad():
action = policy.select_action(batch)
action = postprocessor(action)
print(f"Action: {action}")
print(f"Action prediction successful. Action shape: {action.shape}")
except Exception as e:
print(f"Action prediction failed: {e}")
raise
def test_config_creation():
"""Test policy config creation through factory."""
try:
config = make_policy_config(
policy_type="wall_x",
)
print("Config created successfully through factory")
print(f" Config type: {type(config).__name__}")
except Exception as e:
print(f"Config creation failed: {e}")
raise
if __name__ == "__main__":
test_policy_instantiation()
test_config_creation()