Reduce config and params to the minimum & pass the LeRobot basic test

This commit is contained in:
Geoffrey19
2025-12-04 15:38:04 +08:00
committed by Michel Aractingi
parent feebca050a
commit d10d3ef251
7 changed files with 349 additions and 832 deletions
@@ -13,13 +13,11 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("wall_x") @PreTrainedConfig.register_subclass("wall_x")
@@ -34,48 +32,15 @@ class WallXConfig(PreTrainedConfig):
This config supports multi-modal learning with vision, language, and action data. This config supports multi-modal learning with vision, language, and action data.
""" """
# ==================== Model and Paths Configuration ====================
# Logging
log_name: str = "wall_x_training"
log_project: str = "vla_training"
model_type: str = "wall-oss"
# Pretrained model paths
pretrained_wallx_path: str | None = None # Path to pretrained Wall-X model
save_path: str | None = None # Path to save checkpoints
processor_path: str | None = None # Path to processor (defaults to pretrained_wallx_path)
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
# Tokenizer settings
use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
# ==================== Profiling Configuration ====================
profile: bool = False
profile_save_path: str | None = None
profile_wait_iters: int = 10
profile_warmup_iters: int = 5
profile_active_iters: int = 2
# ==================== Training Hyperparameters ====================
num_warmup_steps: int = 100
num_training_steps: int = 64000000
learning_rate: float = 5e-5
min_lr: float = 5e-5
num_epoch: int = 100
gradient_accumulation_steps: int = 32
batch_size_per_gpu: int = 8
padding_side: str = "left"
epoch_save_interval: int = 10
# Training optimization
fsdp2: bool = False
torch_compile: bool = False
# ==================== Input / Output Structure ==================== # ==================== Input / Output Structure ====================
n_obs_steps: int = 1 n_obs_steps: int = 1
chunk_size: int = 32 # action_horizon in wall-x chunk_size: int = 32 # action_horizon in wall-x
n_action_steps: int = 32 n_action_steps: int = 32
# Action dimension - wall-x uses 20
max_action_dim: int = 20
max_state_dim: int = 20 # For proprioception
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY, "VISUAL": NormalizationMode.IDENTITY,
@@ -84,101 +49,17 @@ class WallXConfig(PreTrainedConfig):
} }
) )
# Action dimension - wall-x uses hardcoded 20 # ==================== Action Prediction ====================
max_action_dim: int = 20 # Pretrained model paths
max_state_dim: int = 20 # For proprioception pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
# Image preprocessing # Action prediction mode: "diffusion" or "fast"
resize_imgs_with_padding: tuple[int, int] | None = None # wall-x uses Qwen processor prediction_mode: str = "diffusion"
# Tokenizer # Tokenizer settings
tokenizer_max_length: int = 256 use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
# ==================== Model Architecture ====================
vlm_model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
load_vlm_weights: bool = True
# Vision config
vision_config: dict = field(default_factory=lambda: {
"depth": 32,
"hidden_size": 3584,
"hidden_act": "silu",
"intermediate_size": 3420,
"num_heads": 16,
"patch_size": 14,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
"window_size": 112,
"out_hidden_size": 3584,
})
# Language model config
hidden_size: int = 3584 # 8192 for 7B model
intermediate_size: int = 18944 # 29568 for 7B model
num_hidden_layers: int = 36 # 80 for 7B model
num_attention_heads: int = 28 # 64 for 7B model
num_key_value_heads: int = 4 # 8 for 7B model
vocab_size: int = 152064
# ==================== Action Prediction ====================
# Action prediction mode: "flow" or "fast"
prediction_mode: str = "flow"
# Flow matching parameters
noise_scheduler: dict = field(default_factory=lambda: {
"beta_alpha": 1.5, # Beta distribution concentration1
"beta_beta": 1.0, # Beta distribution concentration0
"s": 0.999, # Scaling factor for time
})
# Decoding parameters
num_inference_timesteps: int = 10 # Number of ODE solver steps
ode_solver_method: str = "euler" # ODE solver method
# ==================== Robot Configuration ====================
# Degrees of freedom configuration - defines action space
dof_config: dict = field(default_factory=lambda: {
"left_ee_pos": 3,
"left_ee_rot": 3,
"left_gripper": 1,
"right_ee_pos": 3,
"right_ee_rot": 3,
"right_gripper": 1,
})
# Proprioception configuration (typically mirrors dof_config)
agent_pos_config: dict = field(default_factory=lambda: {
"left_ee_pos": 3,
"left_ee_rot": 3,
"left_gripper": 1,
"right_ee_pos": 3,
"right_ee_rot": 3,
"right_gripper": 1,
})
# Customized robot configuration
enable_customized_robot_config: bool = False
customized_robot_config: dict = field(default_factory=lambda: {
"name": "",
"customized_dof_config": {},
"customized_agent_pos_config": {},
})
# Normalization statistics path
norm_stats_path: str | None = None
# ==================== MoE Configuration ====================
num_experts: int = 4
attention_moe: bool = False
mlp_moe: bool = False
# ==================== Finetuning Settings ====================
freeze_vision_encoder: bool = True
train_expert_only: bool = False # wall-x trains more components
train_action_head: bool = True
# Cache
use_cache: bool = True
# ==================== Optimizer Presets ==================== # ==================== Optimizer Presets ====================
optimizer_lr: float = 2e-5 optimizer_lr: float = 2e-5
@@ -191,44 +72,6 @@ class WallXConfig(PreTrainedConfig):
scheduler_decay_steps: int = 100000 scheduler_decay_steps: int = 100000
scheduler_decay_lr: float = 1e-6 scheduler_decay_lr: float = 1e-6
# ==================== Dataset Configuration ====================
# Dataset-specific normalization statistics
action_statistics: dict = field(default_factory=dict)
# Data configuration
data_config: dict = field(default_factory=lambda: {
"use_lerobot": True,
"lerobot_config": {
"repo_id": "",
"root": None,
"episodes": None,
"image_transforms": None,
"delta_timestamps": None,
"tolerance_s": 1e-4,
"revision": None,
"force_cache_sync": False,
"download_videos": True,
"video_backend": None,
},
"action_horizon": 32,
"train_test_split": 0.95,
"obs_action_keys": [],
"predict_action_keys": [],
"resolution": {
"face_view": 256,
"left_wrist_view": 256,
"right_wrist_view": 256,
"move1_view": 256,
"move2_view": 256,
"top_view": 256,
"wall_view": 256,
"multi_modal": 256,
},
})
# ==================== Resume Configuration ====================
resume_config: dict | None = field(default_factory=lambda: None)
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
@@ -239,243 +82,55 @@ class WallXConfig(PreTrainedConfig):
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
) )
if self.prediction_mode not in ["flow", "fast"]: if self.prediction_mode not in ["diffusion", "fast"]:
raise ValueError( raise ValueError(
f"prediction_mode must be 'flow' or 'fast', got {self.prediction_mode}" f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
# Validate dof_config total doesn't exceed max_action_dim
total_dof = sum(self.dof_config.values())
if total_dof > self.max_action_dim:
raise ValueError(
f"Total DOF ({total_dof}) exceeds max_action_dim ({self.max_action_dim})"
) )
# Sync prediction_mode with use_fast_tokenizer # Sync prediction_mode with use_fast_tokenizer
if self.use_fast_tokenizer: if self.use_fast_tokenizer:
self.prediction_mode = "fast" self.prediction_mode = "fast"
else: else:
self.prediction_mode = "flow" self.prediction_mode = "diffusion"
def get_train_config(self) -> dict: def validate_features(self) -> None:
""" """Validate and set up input/output features."""
Extract the complete train_config dictionary matching the YAML training configuration format. image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"Wall-X policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
This method constructs the full train_config from WallXConfig fields, suitable for if "observation.state" not in self.input_features:
training scripts and Qwen2_5_VLMoEForAction.from_pretrained. state_feature = PolicyFeature(
type=FeatureType.STATE,
Returns: shape=(self.max_state_dim,), # Padded to max_state_dim
dict: Complete training configuration matching YAML structure. )
""" self.input_features["observation.state"] = state_feature
# Build customized_robot_config
if self.enable_customized_robot_config and self.customized_robot_config:
customized_robot_config = {
"name": self.customized_robot_config.get("name", ""),
"customized_dof_config": self.customized_robot_config.get(
"customized_dof_config", self.dof_config
),
"customized_agent_pos_config": self.customized_robot_config.get(
"customized_agent_pos_config", self.agent_pos_config
),
}
else: else:
customized_robot_config = { state_shape = self.input_features["observation.state"].shape
"name": self.data_config.get("lerobot_config", {}).get("repo_id", ""), state_dim = state_shape[0] if state_shape else 0
"customized_dof_config": self.dof_config, if state_dim > self.max_state_dim:
"customized_agent_pos_config": self.agent_pos_config, raise ValueError(
} f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
f"Either reduce state dimension or increase max_state_dim in config."
)
train_config = { if "action" not in self.output_features:
# Model and paths configuration action_feature = PolicyFeature(
"log_name": self.log_name, type=FeatureType.ACTION,
"log_project": self.log_project, shape=(self.max_action_dim,), # Padded to max_action_dim
"model_type": self.model_type, )
"pretrained_wallx_path": self.pretrained_wallx_path, self.output_features["action"] = action_feature
"save_path": self.save_path, else:
"use_fast_tokenizer": self.use_fast_tokenizer, action_shape = self.output_features["action"].shape
"action_tokenizer_path": self.action_tokenizer_path, action_dim = action_shape[0] if action_shape else 0
if action_dim > self.max_action_dim:
# Profiling configuration raise ValueError(
"profile": self.profile, f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
"profile_save_path": self.profile_save_path, f"Either reduce action dimension or increase max_action_dim in config."
"profile_wait_iters": self.profile_wait_iters, )
"profile_warmup_iters": self.profile_warmup_iters,
"profile_active_iters": self.profile_active_iters,
# Training hyperparameters
"num_warmup_steps": self.num_warmup_steps,
"num_training_steps": self.num_training_steps,
"learning_rate": self.learning_rate,
"min_lr": self.min_lr,
"num_epoch": self.num_epoch,
"gradient_accumulation_steps": self.gradient_accumulation_steps,
"batch_size_per_gpu": self.batch_size_per_gpu,
"padding_side": self.padding_side,
"epoch_save_interval": self.epoch_save_interval,
# Training optimization
"FSDP2": self.fsdp2,
"torch_compile": self.torch_compile,
# Robot configuration
"dof_config": self.dof_config,
"agent_pos_config": self.agent_pos_config,
# Normalization stats
"norm_stats_path": self.norm_stats_path,
# Customized robot config
"enable_customized_robot_config": self.enable_customized_robot_config,
"customized_robot_config": customized_robot_config,
# Resume configuration
"resume": self.resume_config,
# Data configuration
"data": self.data_config,
}
return train_config
def get_dataload_config(self) -> dict:
"""
Extract data loading configuration from config.
Returns:
dict: Data loading configuration for preprocessing.
"""
return {
"action_horizon": self.data_config.get("action_horizon", self.chunk_size),
"train_test_split": self.data_config.get("train_test_split", 0.95),
"split_seed": 42,
"predict_action_keys": self.data_config.get("predict_action_keys", []),
"obs_action_keys": self.data_config.get("obs_action_keys", []),
"resolution": self.data_config.get("resolution", {}),
"priority_order": None,
"max_length": self.tokenizer_max_length,
}
def get_lerobot_config(self) -> dict:
"""
Extract LeRobot dataset configuration.
Returns:
dict: LeRobot dataset configuration.
"""
return self.data_config.get("lerobot_config", {})
@classmethod
def from_yaml_dict(cls, yaml_dict: dict) -> "WallXConfig":
"""
Create a WallXConfig from a YAML configuration dictionary.
Args:
yaml_dict: Dictionary loaded from YAML training config file.
Returns:
WallXConfig instance with values from YAML.
"""
config_kwargs = {}
# Model and paths
if "log_name" in yaml_dict:
config_kwargs["log_name"] = yaml_dict["log_name"]
if "log_project" in yaml_dict:
config_kwargs["log_project"] = yaml_dict["log_project"]
if "model_type" in yaml_dict:
config_kwargs["model_type"] = yaml_dict["model_type"]
if "pretrained_wallx_path" in yaml_dict:
config_kwargs["pretrained_wallx_path"] = yaml_dict["pretrained_wallx_path"]
if "save_path" in yaml_dict:
config_kwargs["save_path"] = yaml_dict["save_path"]
if "use_fast_tokenizer" in yaml_dict:
config_kwargs["use_fast_tokenizer"] = yaml_dict["use_fast_tokenizer"]
if "action_tokenizer_path" in yaml_dict:
config_kwargs["action_tokenizer_path"] = yaml_dict["action_tokenizer_path"]
# Profiling
if "profile" in yaml_dict:
config_kwargs["profile"] = yaml_dict["profile"]
if "profile_save_path" in yaml_dict:
config_kwargs["profile_save_path"] = yaml_dict["profile_save_path"]
if "profile_wait_iters" in yaml_dict:
config_kwargs["profile_wait_iters"] = yaml_dict["profile_wait_iters"]
if "profile_warmup_iters" in yaml_dict:
config_kwargs["profile_warmup_iters"] = yaml_dict["profile_warmup_iters"]
if "profile_active_iters" in yaml_dict:
config_kwargs["profile_active_iters"] = yaml_dict["profile_active_iters"]
# Training hyperparameters
if "num_warmup_steps" in yaml_dict:
config_kwargs["num_warmup_steps"] = yaml_dict["num_warmup_steps"]
config_kwargs["scheduler_warmup_steps"] = yaml_dict["num_warmup_steps"]
if "num_training_steps" in yaml_dict:
config_kwargs["num_training_steps"] = yaml_dict["num_training_steps"]
config_kwargs["scheduler_decay_steps"] = yaml_dict["num_training_steps"]
if "learning_rate" in yaml_dict:
config_kwargs["learning_rate"] = yaml_dict["learning_rate"]
config_kwargs["optimizer_lr"] = yaml_dict["learning_rate"]
if "min_lr" in yaml_dict:
config_kwargs["min_lr"] = yaml_dict["min_lr"]
config_kwargs["scheduler_decay_lr"] = yaml_dict["min_lr"]
if "num_epoch" in yaml_dict:
config_kwargs["num_epoch"] = yaml_dict["num_epoch"]
if "gradient_accumulation_steps" in yaml_dict:
config_kwargs["gradient_accumulation_steps"] = yaml_dict["gradient_accumulation_steps"]
if "batch_size_per_gpu" in yaml_dict:
config_kwargs["batch_size_per_gpu"] = yaml_dict["batch_size_per_gpu"]
if "padding_side" in yaml_dict:
config_kwargs["padding_side"] = yaml_dict["padding_side"]
if "epoch_save_interval" in yaml_dict:
config_kwargs["epoch_save_interval"] = yaml_dict["epoch_save_interval"]
# Training optimization
if "FSDP2" in yaml_dict:
config_kwargs["fsdp2"] = yaml_dict["FSDP2"]
if "torch_compile" in yaml_dict:
config_kwargs["torch_compile"] = yaml_dict["torch_compile"]
# Robot configuration
if "dof_config" in yaml_dict:
config_kwargs["dof_config"] = yaml_dict["dof_config"]
if "agent_pos_config" in yaml_dict:
config_kwargs["agent_pos_config"] = yaml_dict["agent_pos_config"]
# Normalization stats
if "norm_stats_path" in yaml_dict:
config_kwargs["norm_stats_path"] = yaml_dict["norm_stats_path"]
# Customized robot config
if "enable_customized_robot_config" in yaml_dict:
config_kwargs["enable_customized_robot_config"] = yaml_dict["enable_customized_robot_config"]
if "customized_robot_config" in yaml_dict:
config_kwargs["customized_robot_config"] = yaml_dict["customized_robot_config"]
# Resume config
if "resume" in yaml_dict:
config_kwargs["resume_config"] = yaml_dict["resume"]
# Data configuration
if "data" in yaml_dict:
data = yaml_dict["data"]
data_config = {
"use_lerobot": data.get("use_lerobot", True),
"action_horizon": data.get("action_horizon", 32),
"train_test_split": data.get("train_test_split", 0.95),
"obs_action_keys": data.get("obs_action_keys", []),
"predict_action_keys": data.get("predict_action_keys", []),
"resolution": data.get("resolution", {}),
}
if "lerobot_config" in data:
data_config["lerobot_config"] = data["lerobot_config"]
config_kwargs["data_config"] = data_config
# Set chunk_size from action_horizon
if "action_horizon" in data:
config_kwargs["chunk_size"] = data["action_horizon"]
config_kwargs["n_action_steps"] = data["action_horizon"]
return cls(**config_kwargs)
def get_optimizer_preset(self) -> AdamWConfig: def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig( return AdamWConfig(
@@ -496,7 +151,7 @@ class WallXConfig(PreTrainedConfig):
@property @property
def observation_delta_indices(self) -> list: def observation_delta_indices(self) -> list:
return [0] return None
@property @property
def action_delta_indices(self) -> list: def action_delta_indices(self) -> list:
+13 -7
View File
@@ -16,15 +16,9 @@
""" """
Wall-X Constants and Configuration Data. Wall-X Constants and Configuration Data.
Contains dataset names, key mappings, frequency mappings, and action statistics
for cross-embodiment robotic control.
""" """
from pathlib import Path from lerobot.utils.constants import OBS_STATE, OBS_IMAGES, ACTION
# Add wall-x repo to path if available
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
CAMERA_NAME_MAPPING = { CAMERA_NAME_MAPPING = {
"face_view": "front view", "face_view": "front view",
@@ -35,3 +29,15 @@ CAMERA_NAME_MAPPING = {
"wall_view": "wall view", "wall_view": "wall view",
"top_view": "top view", "top_view": "top view",
} }
RESOLUTION = 256
# Parameters for preprocessing
MAX_PIXELS = 16384 * 28 * 28
MIN_PIXELS = 4 * 28 * 28
IMAGE_FACTOR = 28
PRIORITY_ORDER = None
GENERATE_SUBTASK_RATIO = 0.0
MODEL_TYPE = "qwen2_5"
TOKENIZER_MAX_LENGTH = 768
+128 -365
View File
@@ -38,6 +38,7 @@ import builtins
import glob import glob
import math import math
import os import os
from os import PathLike
import sys import sys
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
@@ -57,16 +58,8 @@ from torchdiffeq import odeint
from transformers import AutoConfig, AutoProcessor from transformers import AutoConfig, AutoProcessor
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.cache_utils import ( from transformers.cache_utils import (
Cache,
DynamicCache,
SlidingWindowCache,
StaticCache, StaticCache,
) )
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_torchdynamo_compiling, logging from transformers.utils import is_torchdynamo_compiling, logging
from transformers import AutoProcessor, BatchFeature from transformers import AutoProcessor, BatchFeature
from qwen_vl_utils.vision_process import smart_resize from qwen_vl_utils.vision_process import smart_resize
@@ -80,29 +73,18 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
from lerobot.policies.wall_x.utils import * from lerobot.policies.wall_x.utils import *
from lerobot.policies.wall_x.constant import * from lerobot.policies.wall_x.constant import *
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2RMSNorm,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLRotaryEmbedding,
Qwen2_5_VLPreTrainedModel,
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration,
) )
from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import ( from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLDecoderLayer_with_MoE,
Qwen2_5_VLACausalLMOutputWithPast, Qwen2_5_VLACausalLMOutputWithPast,
Qwen2_5_VLMoEModel, Qwen2_5_VLMoEModel,
) )
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
# Add wall-x repo to path if available
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
if WALL_X_PATH.exists():
sys.path.insert(0, str(WALL_X_PATH))
class SinusoidalPosEmb(nn.Module): class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embedding for diffusion timesteps.""" """Sinusoidal positional embedding for diffusion timesteps."""
@@ -129,7 +111,7 @@ class ActionHead(nn.Module):
for action sequence prediction. for action sequence prediction.
""" """
def __init__(self, config: WallXConfig): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
@@ -138,10 +120,9 @@ class ActionHead(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Beta distribution for noise scheduling # Beta distribution for noise scheduling
noise_config = config.noise_scheduler self.beta_alpha = 1.5
self.beta_alpha = noise_config.get("beta_alpha", 1.5) self.beta_beta = 1.0
self.beta_beta = noise_config.get("beta_beta", 1.0) self.s = 0.999
self.s = noise_config.get("s", 0.999)
# Sinusoidal timestep embedding # Sinusoidal timestep embedding
self.time_embed = SinusoidalPosEmb(config.hidden_size) self.time_embed = SinusoidalPosEmb(config.hidden_size)
@@ -277,12 +258,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
pretrained_model_path, pretrained_name_or_path,
train_config, config=None,
config_path=None,
processor_path=None,
action_tokenizer_path=None, action_tokenizer_path=None,
**kwargs, cache_dir: str | PathLike | None = None,
force_download: bool = False,
local_files_only: bool = False,
token: str | bool | None = None,
revision: str = "main",
strict: bool = False,
**kwargs: Any
): ):
""" """
Load model from pretrained model path. Load model from pretrained model path.
@@ -290,17 +275,24 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
Args: Args:
pretrained_model_path (str): Model directory path containing model.safetensors file pretrained_model_path (str): Model directory path containing model.safetensors file
config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
processor_path (str, optional): Processor path, if None will load from default config
action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
**kwargs: Additional arguments **kwargs: Additional arguments
Returns: Returns:
Qwen2_5_VLMoEForAction: Loaded model instance Qwen2_5_VLMoEForAction: Loaded model instance
""" """
# Load model components from pretrained path if config is None:
config_path = os.path.join(pretrained_model_path, "config.json") config = cls.config_class.from_pretrained(
config = cls.config_class.from_pretrained(config_path) pretrained_name_or_path,
processor = AutoProcessor.from_pretrained(pretrained_model_path, use_fast=True) cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
strict=strict,
**kwargs,
)
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
if action_tokenizer_path is not None: if action_tokenizer_path is not None:
processor.action_processor = AutoProcessor.from_pretrained( processor.action_processor = AutoProcessor.from_pretrained(
action_tokenizer_path, trust_remote_code=True action_tokenizer_path, trust_remote_code=True
@@ -312,22 +304,41 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Resize token embeddings to match processor tokenizer vocabulary size # Resize token embeddings to match processor tokenizer vocabulary size
model.resize_token_embeddings(len(processor.tokenizer)) model.resize_token_embeddings(len(processor.tokenizer))
# Load model state dict from safetensors file # Try to load the model.safetensors file
safetensor_files = glob.glob( print(f"Loading model from: {pretrained_name_or_path}")
os.path.join(pretrained_model_path, "*.safetensors") try:
) from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
from safetensors.torch import load_file
sd = load_file(resolved_file)
print("✓ Loaded state dict from model.safetensors")
except Exception as e:
print(f"Could not load state dict from remote files: {e}")
print("Returning model without loading pretrained weights")
return model
state_dict = {} state_dict = {}
for file in safetensor_files: # filter normalizer statistic params
sd = load_file(file, device="cpu") del_keys = []
# filter normalizer statistic params for key in sd.keys():
del_keys = [] if "action_preprocessor.normalizer" in key:
for key in sd.keys(): del_keys.append(key)
if "action_preprocessor.normalizer" in key: for key in del_keys:
print(f"filter load model weight {key}") del sd[key]
del_keys.append(key) state_dict.update(sd)
for key in del_keys:
del sd[key]
state_dict.update(sd)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
@@ -730,7 +741,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
rope_deltas: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None,
dataset_names: Optional[str] = None,
dof_mask: Optional[torch.FloatTensor] = None, dof_mask: Optional[torch.FloatTensor] = None,
agent_pos_mask: Optional[torch.FloatTensor] = None, agent_pos_mask: Optional[torch.FloatTensor] = None,
**kwargs, **kwargs,
@@ -763,7 +773,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
rope_deltas (torch.LongTensor, optional): RoPE position deltas rope_deltas (torch.LongTensor, optional): RoPE position deltas
cache_position (torch.LongTensor, optional): Cache position indices cache_position (torch.LongTensor, optional): Cache position indices
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
dataset_names (str, optional): Names of datasets in the current batch
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens
agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data
**kwargs: Additional keyword arguments **kwargs: Additional keyword arguments
@@ -872,7 +881,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
) )
proprioception = self.action_preprocessor.proprioception_proj( proprioception = self.action_preprocessor.proprioception_proj(
proprioception, proprioception,
dataset_names,
agent_pos_mask, agent_pos_mask,
use_history=proprioception.shape[1] > 1, use_history=proprioception.shape[1] > 1,
) )
@@ -909,7 +917,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
) )
dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
noisy_action_emb, flow = self.action_preprocessor( noisy_action_emb, flow = self.action_preprocessor(
action_chunk, dataset_names, dof_mask action_chunk, dof_mask
) )
mask = input_ids == self.action_token_id_set["action_token_id"] mask = input_ids == self.action_token_id_set["action_token_id"]
mask_unsqueezed = mask.unsqueeze(-1) mask_unsqueezed = mask.unsqueeze(-1)
@@ -953,18 +961,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Compute losses if labels are provided # Compute losses if labels are provided
if labels is not None: if labels is not None:
loss = 0 loss = 0
action_accuracy = 0
unique_datasets_name = list(set(dataset_names))
# Initialize per-dataset loss tracking dictionaries
channel_loss_dict = {
dataset_name: torch.tensor(0.0, device=logits.device)
for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES
}
channel_loss_count_dict = {
dataset_name: torch.tensor(0, device=logits.device)
for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES
}
# Compute standard cross-entropy loss for language modeling # Compute standard cross-entropy loss for language modeling
shift_logits = logits[..., :-1, :].contiguous() shift_logits = logits[..., :-1, :].contiguous()
@@ -982,22 +978,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
else torch.tensor(0.0, device=shift_logits.device) else torch.tensor(0.0, device=shift_logits.device)
) )
# Compute per-dataset channel losses
_cross_entropy_loss = _cross_entropy_loss.view(batch_size, seq_length - 1)
non_ignored_mask = non_ignored_mask.view(batch_size, seq_length - 1)
for dataset_name_i in unique_datasets_name:
dataset_mask = torch.tensor(
[name == dataset_name_i for name in dataset_names],
device=logits.device,
)
combined_mask = dataset_mask.unsqueeze(1) & non_ignored_mask
channel_loss_dict[dataset_name_i] = (
_cross_entropy_loss[combined_mask].sum()
if combined_mask.any()
else torch.tensor(0.0, device=shift_logits.device)
)
channel_loss_count_dict[dataset_name_i] += combined_mask.sum()
# Add cross-entropy loss to total loss if valid # Add cross-entropy loss to total loss if valid
if not torch.isnan(cross_entropy_loss): if not torch.isnan(cross_entropy_loss):
loss += cross_entropy_loss loss += cross_entropy_loss
@@ -1005,20 +985,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
with torch.no_grad(): with torch.no_grad():
cross_entropy_loss.detach() cross_entropy_loss.detach()
# Compute action token prediction accuracy
shift_logits = logits[..., :-1, :].contiguous()
action_preds = shift_logits.argmax(dim=-1)
shift_labels = labels[..., 1:].contiguous()
if self.use_fast_tokenizer:
action_mask = (
shift_labels > self.action_token_id_set["fast_action_token_list"][0]
)
correct_preds = (action_preds == shift_labels) & action_mask
action_accuracy = (
correct_preds.sum().float() / action_mask.sum().float()
)
channel_loss_dict["action_accuracy"] = action_accuracy
if action_chunk is not None: if action_chunk is not None:
action_mask = input_ids == self.action_token_id_set["action_token_id"] action_mask = input_ids == self.action_token_id_set["action_token_id"]
if action_mask.any(): if action_mask.any():
@@ -1101,7 +1067,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None,
num_inference_timesteps: Optional[int] = 10, num_inference_timesteps: Optional[int] = 10,
dataset_names: Optional[str] = None,
dof_mask: Optional[torch.FloatTensor] = None, dof_mask: Optional[torch.FloatTensor] = None,
agent_pos_mask: Optional[torch.FloatTensor] = None, agent_pos_mask: Optional[torch.FloatTensor] = None,
re_generate: bool = False, re_generate: bool = False,
@@ -1140,7 +1105,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
cache_position (torch.LongTensor, optional): Cache position indices cache_position (torch.LongTensor, optional): Cache position indices
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
num_inference_timesteps (int, optional): Number of diffusion inference steps num_inference_timesteps (int, optional): Number of diffusion inference steps
dataset_names (str, optional): Dataset names for normalization
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask dof_mask (torch.FloatTensor, optional): Degrees of freedom mask
agent_pos_mask (torch.FloatTensor, optional): Agent position mask agent_pos_mask (torch.FloatTensor, optional): Agent position mask
re_generate (bool, optional): Whether to use sampling for regeneration re_generate (bool, optional): Whether to use sampling for regeneration
@@ -1239,7 +1203,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
) )
proprio_embed = self.action_preprocessor.proprioception_proj( proprio_embed = self.action_preprocessor.proprioception_proj(
proprioception, proprioception,
dataset_names,
agent_pos_mask, agent_pos_mask,
use_history=proprioception.shape[1] > 1, use_history=proprioception.shape[1] > 1,
) )
@@ -1339,7 +1302,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
"dof_mask": dof_mask, "dof_mask": dof_mask,
"agent_pos_mask": agent_pos_mask, "agent_pos_mask": agent_pos_mask,
"proprioception": proprioception, "proprioception": proprioception,
"dataset_names": dataset_names,
} }
# Generate output tokens # Generate output tokens
@@ -1545,7 +1507,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
image_grid_thw=None, image_grid_thw=None,
video_grid_thw=None, video_grid_thw=None,
second_per_grid_ts=None, second_per_grid_ts=None,
dataset_names=None,
proprioception=None, proprioception=None,
dof_mask=None, dof_mask=None,
agent_pos_mask=None, agent_pos_mask=None,
@@ -1572,7 +1533,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
image_grid_thw: Image grid dimensions image_grid_thw: Image grid dimensions
video_grid_thw: Video grid dimensions video_grid_thw: Video grid dimensions
second_per_grid_ts: Time interval per temporal grid second_per_grid_ts: Time interval per temporal grid
dataset_names: Dataset names for processing
proprioception: Proprioceptive sensor data proprioception: Proprioceptive sensor data
dof_mask: Degrees of freedom mask dof_mask: Degrees of freedom mask
agent_pos_mask: Agent position mask agent_pos_mask: Agent position mask
@@ -1677,7 +1637,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
"cache_position": cache_position, "cache_position": cache_position,
"second_per_grid_ts": second_per_grid_ts, "second_per_grid_ts": second_per_grid_ts,
"proprioception": proprioception, "proprioception": proprioception,
"dataset_names": dataset_names,
"dof_mask": dof_mask, "dof_mask": dof_mask,
"agent_pos_mask": agent_pos_mask, "agent_pos_mask": agent_pos_mask,
} }
@@ -1866,150 +1825,14 @@ class WallXPolicy(PreTrainedPolicy):
config.validate_features() config.validate_features()
self.config = config self.config = config
# Initialize VLM wrapper # Initialize the wall-x model
self.model = Qwen2_5_VLMoEForAction(config) self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
self.model.to(config.device)
# Convert to bfloat16 for Flash Attention compatibility
self.model.to_bfloat16_for_selected_params()
self.reset() self.reset()
@classmethod
def from_pretrained(
    cls: builtins.type[T],
    pretrained_name_or_path: str | Path,
    *,
    config: PreTrainedConfig | None = None,
    force_download: bool = False,
    resume_download: bool | None = None,
    proxies: dict | None = None,
    token: str | bool | None = None,
    cache_dir: str | Path | None = None,
    local_files_only: bool = False,
    revision: str | None = None,
    strict: bool = False,
    **kwargs,
) -> T:
    """
    Load WallXPolicy from a pretrained model path.

    Weight loading is best-effort: if ``model.safetensors`` cannot be
    resolved or read, or if applying the state dict raises, a message is
    printed and the freshly constructed (un-loaded) model is returned
    instead of propagating the error.

    Args:
        pretrained_name_or_path: Path to pretrained model or model identifier
        config: Optional configuration object
        force_download: Force download even if cached
        resume_download: Resume interrupted download
        proxies: Proxy configuration
        token: Authentication token
        cache_dir: Cache directory path
        local_files_only: Only use local files
        revision: Model revision
        strict: Strict loading of state dict
        **kwargs: Additional arguments, forwarded both to the config loader
            and to the policy constructor

    Returns:
        WallXPolicy: Loaded policy instance

    Raises:
        ValueError: If ``pretrained_name_or_path`` is None.
    """
    print(
        "Loading Wall-X model for cross-embodiment robotic control.\n"
        "This implementation integrates Qwen2.5-VL with flow matching for action prediction."
    )
    if pretrained_name_or_path is None:
        raise ValueError("pretrained_name_or_path is required")

    # Use provided config if available, otherwise load from pretrained path
    if config is None:
        config = PreTrainedConfig.from_pretrained(
            pretrained_name_or_path=pretrained_name_or_path,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
            **kwargs,
        )

    # Initialize model without loading weights
    model = cls(config, **kwargs)

    def _report_keys(label, keys):
        """Print a summary of `keys`, listing at most the first five."""
        if not keys:
            return
        print(f"{label} keys when loading state dict: {len(keys)} keys")
        for key in keys[:5]:
            print(f" - {key}")
        if len(keys) > 5:
            print(f" ... and {len(keys) - 5} more")

    # Load and remap the state dict
    try:
        print(f"Loading model from: {pretrained_name_or_path}")
        try:
            from transformers.utils import cached_file

            # Resolve the safetensors checkpoint (local path or hub cache)
            resolved_file = cached_file(
                pretrained_name_or_path,
                "model.safetensors",
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                revision=revision,
                local_files_only=local_files_only,
            )
            original_state_dict = load_file(resolved_file)
            print("✓ Loaded state dict from model.safetensors")
        except Exception as e:
            # The exception must be bound here: the message below formats it.
            # Previously `e` was unbound in this handler, so the real cause
            # was replaced by a NameError reported by the outer handler.
            print(f"Could not load state dict: {e}")
            print("Returning model without loading pretrained weights")
            return model

        # Filter out normalizer statistics if present
        filtered_state_dict = {}
        for key, value in original_state_dict.items():
            if "action_preprocessor.normalizer" not in key:
                filtered_state_dict[key] = value
            else:
                print(f"Filtered key: {key}")

        # Add "model." prefix for keys that don't have it
        remapped_state_dict = {}
        remap_count = 0
        for key, value in filtered_state_dict.items():
            if not key.startswith("model."):
                remapped_state_dict[f"model.{key}"] = value
                remap_count += 1
            else:
                remapped_state_dict[key] = value
        if remap_count > 0:
            print(f"Remapped {remap_count} state dict keys")

        # Load the remapped state dict into the model
        missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
        _report_keys("Missing", missing_keys)
        _report_keys("Unexpected", unexpected_keys)
        if not missing_keys and not unexpected_keys:
            print("All keys loaded successfully!")
    except Exception as e:
        # Best-effort: any failure while filtering/remapping/applying the
        # weights (including a strict-mode mismatch) is reported, not raised.
        print(f"Warning: Could not load state dict: {e}")

    return model
def reset(self): def reset(self):
"""Reset action queue.""" """Reset action queue."""
self._queues = { self._queues = {
@@ -2022,88 +1845,50 @@ class WallXPolicy(PreTrainedPolicy):
def preprocess_inputs( def preprocess_inputs(
self, self,
batch: List[Dict[str, Any]], batch: Dict[str, Any],
config: Dict[str, Any],
dataload_config: Dict[str, Any],
lerobot_config: Dict[str, Any],
processor: Any,
action_tokenizer: Optional[Any] = None,
camera_keys: Optional[List[str]] = None,
) -> BatchFeature: ) -> BatchFeature:
""" """
Convert a batch of LeRobot dataset items to Wall-X model input format. Convert a batch of LeRobot dataset items to Wall-X model input format.
This is the batch version of convert_lerobot_to_wallx_format, processing multiple This processes a batched dictionary where tensors have batch dimension first.
samples together for efficient batched inference or training.
Args: Args:
batch: List of items from LeRobot dataset batch: Dictionary with batched tensors:
config: Model configuration dict - "observation.state": (batch_size, state_dim) or (batch_size, n_obs_steps, state_dim)
dataload_config: Data loading configuration - "action": (batch_size, chunk_size, action_dim)
norm_stats: Normalization statistics - "observation.images.<key>": (batch_size, C, H, W)
lerobot_config: LeRobot config containing 'repo_id' - "task": List[str] of length batch_size
processor: Hugging Face processor for tokenization
action_tokenizer: Optional action tokenizer
camera_keys: List of camera keys in the dataset
Returns: Returns:
BatchFeature containing batched model inputs BatchFeature containing batched model inputs
Example:
>>> batch = [dataset[i] for i in range(8)]
>>> result = preprocess_inputs(
... batch=batch,
... config=config,
... dataload_config=dataload_config,
... norm_stats=norm_stats,
... lerobot_config={'repo_id': 'lerobot/aloha_mobile_cabinet'},
... processor=processor,
... )
""" """
repo_id = lerobot_config["repo_id"] use_fast_tokenizer = self.config.use_fast_tokenizer
use_fast_tokenizer = config.get("use_fast_tokenizer", False)
# Get key mappings # Get batch size from state tensor
cam_key_mapping = KEY_MAPPINGS[repo_id]["camera"] batch_size = batch[OBS_STATE].shape[0]
state_key = KEY_MAPPINGS[repo_id]["state"]
action_key = KEY_MAPPINGS[repo_id]["action"]
# Build data config
data_config = X2RDataProcessingConfig().update(
train_test_split=dataload_config.get("train_test_split", 0.95),
split_seed=dataload_config.get("split_seed", 42),
predict_action_keys=dataload_config.get("predict_action_keys", []),
obs_action_keys=dataload_config.get("obs_action_keys", []),
resolution=dataload_config.get("resolution", None),
priority_order=dataload_config.get("priority_order", None),
)
if camera_keys is None:
camera_keys = list(cam_key_mapping.keys())
# ==================== PROCESS ALL SAMPLES ==================== # ==================== PROCESS ALL SAMPLES ====================
all_image_inputs = [] all_image_inputs = []
all_texts = [] all_texts = []
all_agent_pos = []
all_actions = []
all_frame_indices = []
for data in batch: # Find image keys in batch
img_keys = [key for key in self.config.image_features if key in batch]
for i in range(batch_size):
# Vision preprocessing per sample # Vision preprocessing per sample
processed_frames = [] processed_frames = []
orig_height, orig_width = None, None orig_height, orig_width = None, None
resized_height, resized_width = None, None resized_height, resized_width = None, None
for key in camera_keys: for key in img_keys:
current_obs = data[key].clone() current_obs = batch[key][i].clone() # (C, H, W)
if current_obs.dim() == 3: if current_obs.dim() == 3:
current_obs = current_obs.permute(1, 2, 0) current_obs = current_obs.permute(1, 2, 0) # (H, W, C)
img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy()) img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy())
orig_width, orig_height = img_pil.size orig_width, orig_height = img_pil.size
cam_name = cam_key_mapping.get(key, key) target_size = RESOLUTION
target_size = data_config.resolution.get(cam_name, -1)
if target_size != -1: if target_size != -1:
if orig_width > orig_height: if orig_width > orig_height:
new_width = target_size new_width = target_size
@@ -2117,9 +1902,9 @@ class WallXPolicy(PreTrainedPolicy):
resized_height, resized_width = smart_resize( resized_height, resized_width = smart_resize(
current_height, current_height,
current_width, current_width,
factor=data_config.image_factor, factor=IMAGE_FACTOR,
min_pixels=data_config.min_pixels, min_pixels=MIN_PIXELS,
max_pixels=data_config.max_pixels, max_pixels=MAX_PIXELS,
) )
resized_img = img_pil.resize((resized_width, resized_height)) resized_img = img_pil.resize((resized_width, resized_height))
processed_frames.append(resized_img) processed_frames.append(resized_img)
@@ -2127,33 +1912,30 @@ class WallXPolicy(PreTrainedPolicy):
all_image_inputs.append(processed_frames) all_image_inputs.append(processed_frames)
# Text preprocessing # Text preprocessing
frame_index = data["frame_index"] task_text = batch["task"][i] if isinstance(batch["task"], list) else batch["task"]
instruction_info = {"instruction": data["task"]} instruction_info = {"instruction": task_text}
frame_index = batch["frame_index"][i] if "frame_index" in batch else 0
complete_text, _ = get_wallx_normal_text( complete_text, _ = get_wallx_normal_text(
instruction_info, instruction_info,
dataload_config.get("action_horizon", 33) - 1, self.config.chunk_size,
frame_index, frame_index,
data_config.priority_order, PRIORITY_ORDER,
cam_key_mapping, img_keys,
generate_subtask_ratio=data_config.generate_subtask_ratio, generate_subtask_ratio=GENERATE_SUBTASK_RATIO,
) )
text = process_grounding_points( text = process_grounding_points(
complete_text, orig_height, orig_width, resized_height, resized_width, complete_text, orig_height, orig_width, resized_height, resized_width,
data_config.model_type MODEL_TYPE
) )
all_texts.append(text) all_texts.append(text)
# Collect raw values
all_agent_pos.append(data[state_key])
all_actions.append(data[action_key])
all_frame_indices.append(frame_index)
# Stack agent_pos # ==================== PROCESS AGENT POS ====================
agent_pos = torch.stack(all_agent_pos) agent_pos = batch[OBS_STATE] # (batch_size, state_dim)
if agent_pos.dim() == 2: if agent_pos.dim() == 2:
agent_pos = agent_pos.unsqueeze(1) agent_pos = agent_pos.unsqueeze(1) # (batch_size, 1, state_dim)
agent_pos_mask = (~torch.isnan(agent_pos)).float() agent_pos_mask = (~torch.isnan(agent_pos)).float()
agent_pos = agent_pos.nan_to_num(nan=0.0) agent_pos = agent_pos.nan_to_num(nan=0.0)
@@ -2161,15 +1943,15 @@ class WallXPolicy(PreTrainedPolicy):
pad_size = 20 - agent_pos.shape[-1] pad_size = 20 - agent_pos.shape[-1]
agent_pos = torch.cat([ agent_pos = torch.cat([
agent_pos, agent_pos,
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size) torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device)
], dim=-1) ], dim=-1)
agent_pos_mask = torch.cat([ agent_pos_mask = torch.cat([
agent_pos_mask, agent_pos_mask,
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size) torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size, device=agent_pos_mask.device)
], dim=-1) ], dim=-1)
# Stack actions # ==================== PROCESS ACTIONS ====================
action = torch.stack(all_actions) action = batch[ACTION] # (batch_size, chunk_size, action_dim)
if action.dim() == 2: if action.dim() == 2:
action = action.unsqueeze(1) action = action.unsqueeze(1)
dof_mask = (~torch.isnan(action)).float() dof_mask = (~torch.isnan(action)).float()
@@ -2179,36 +1961,35 @@ class WallXPolicy(PreTrainedPolicy):
pad_size = 20 - action.shape[-1] pad_size = 20 - action.shape[-1]
action = torch.cat([ action = torch.cat([
action, action,
torch.zeros(action.shape[0], action.shape[1], pad_size) torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
], dim=-1) ], dim=-1)
dof_mask = torch.cat([ dof_mask = torch.cat([
dof_mask, dof_mask,
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size) torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
], dim=-1) ], dim=-1)
# ==================== ACTION TOKEN REPLACEMENT ==================== # ==================== ACTION TOKEN REPLACEMENT ====================
all_texts = replace_action_token( all_texts = replace_action_token(
all_texts, all_texts,
action, action,
action_tokenizer if use_fast_tokenizer else None, self.model.action_tokenizer if use_fast_tokenizer else None,
[repo_id] * len(batch),
dof_mask, dof_mask,
) )
# ==================== TOKENIZATION ==================== # ==================== TOKENIZATION ====================
inputs = preprocesser_call( inputs = preprocesser_call(
processor=processor, processor=self.model.processor,
text=all_texts, text=all_texts,
images=all_image_inputs, images=all_image_inputs,
videos=None, videos=None,
padding=True, padding=True,
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
max_length=dataload_config.get("max_length", 768), max_length=TOKENIZER_MAX_LENGTH,
) )
# ==================== ADDITIONAL INPUTS ==================== # ==================== ADDITIONAL INPUTS ====================
action_token_id = processor.tokenizer.convert_tokens_to_ids("<|action|>") action_token_id = self.model.processor.tokenizer.convert_tokens_to_ids("<|action|>")
moe_token_types = inputs.input_ids == action_token_id moe_token_types = inputs.input_ids == action_token_id
inputs["proprioception"] = agent_pos inputs["proprioception"] = agent_pos
@@ -2216,11 +1997,13 @@ class WallXPolicy(PreTrainedPolicy):
inputs["action_chunk"] = action inputs["action_chunk"] = action
inputs["dof_mask"] = dof_mask inputs["dof_mask"] = dof_mask
inputs["moe_token_types"] = moe_token_types inputs["moe_token_types"] = moe_token_types
inputs["dataset_names"] = [repo_id] * len(batch) inputs["frame_index"] = batch["frame_index"] if "frame_index" in batch else torch.zeros(batch_size, device=batch[OBS_STATE].device)
inputs["frame_index"] = torch.stack([
torch.tensor(fi) if not isinstance(fi, torch.Tensor) else fi # Move all tensors to the correct device
for fi in all_frame_indices device = self.config.device
]) for key, value in inputs.items():
if isinstance(value, torch.Tensor):
inputs[key] = value.to(device)
return inputs return inputs
@@ -2232,37 +2015,19 @@ class WallXPolicy(PreTrainedPolicy):
batch: Dictionary containing preprocessed inputs from preprocess_inputs() batch: Dictionary containing preprocessed inputs from preprocess_inputs()
Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw, Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw,
proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types, proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types,
dataset_names, etc. etc.
Returns: Returns:
tuple: (loss, loss_dict) tuple: (loss, loss_dict)
""" """
batch = self.preprocess_inputs( batch = self.preprocess_inputs(
batch, batch,
self.config,
self.dataload_config,
self.lerobot_config,
self.processor,
self.action_tokenizer,
self.camera_keys
) )
# Call the underlying model's forward with mode="train" # Call the underlying model's forward with mode="train"
outputs = self.model( outputs = self.model(
mode="train", **batch,
input_ids=batch.get("input_ids"), mode="train"
attention_mask=batch.get("attention_mask"),
pixel_values=batch.get("pixel_values"),
image_grid_thw=batch.get("image_grid_thw"),
pixel_values_videos=batch.get("pixel_values_videos"),
video_grid_thw=batch.get("video_grid_thw"),
proprioception=batch.get("proprioception"),
agent_pos_mask=batch.get("agent_pos_mask"),
action_chunk=batch.get("action_chunk"),
dof_mask=batch.get("dof_mask"),
moe_token_types=batch.get("moe_token_types"),
dataset_names=batch.get("dataset_names"),
labels=batch.get("labels", batch.get("input_ids")), # Use input_ids as labels if not provided
) )
# Extract losses from output # Extract losses from output
@@ -2292,16 +2057,10 @@ class WallXPolicy(PreTrainedPolicy):
batch = self.preprocess_inputs( batch = self.preprocess_inputs(
batch, batch,
self.config,
self.dataload_config,
self.lerobot_config,
self.processor,
self.action_tokenizer,
self.camera_keys
) )
if self.config.prediction_mode == "diffusion": if self.config.prediction_mode == "diffusion":
actions = self.model( output = self.model(
**batch, **batch,
action_dim=self.config.max_action_dim, action_dim=self.config.max_action_dim,
pred_horizon=self.config.chunk_size, pred_horizon=self.config.chunk_size,
@@ -2309,9 +2068,9 @@ class WallXPolicy(PreTrainedPolicy):
predict_mode="diffusion" predict_mode="diffusion"
) )
elif self.config.prediction_mode == "fast": elif self.config.prediction_mode == "fast":
actions = self.model( output = self.model(
**batch, **batch,
action_dim=self.config.action_feature.shape[0], action_dim=self.config.output_features["action"].shape[0],
pred_horizon=self.config.chunk_size, pred_horizon=self.config.chunk_size,
mode="predict", mode="predict",
predict_mode="fast" predict_mode="fast"
@@ -2319,8 +2078,12 @@ class WallXPolicy(PreTrainedPolicy):
else: else:
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented") raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
# Unpad actions # Extract action tensor from output dictionary
actions = actions[:, :, :self.config.action_feature.shape[0]] actions = output["predict_action"]
# Unpad actions to actual action dimension
action_dim = self.config.output_features["action"].shape[0]
actions = actions[:, :, :action_dim]
return actions return actions
@@ -29,13 +29,10 @@ from lerobot.processor import (
PolicyProcessorPipeline, PolicyProcessorPipeline,
ProcessorStepRegistry, ProcessorStepRegistry,
RenameObservationsProcessorStep, RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep, UnnormalizerProcessorStep,
) )
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_wall_x_pre_post_processors( def make_wall_x_pre_post_processors(
config: WallXConfig, config: WallXConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
@@ -49,7 +46,6 @@ def make_wall_x_pre_post_processors(
The pre-processing pipeline prepares input data for the model by: The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations 1. Renaming features to match pretrained configurations
2. Adding a batch dimension 2. Adding a batch dimension
3. Tokenizing language task descriptions
4. Normalizing input and output features based on dataset statistics 4. Normalizing input and output features based on dataset statistics
5. Moving all data to the specified device 5. Moving all data to the specified device
@@ -65,25 +61,10 @@ def make_wall_x_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines A tuple containing the configured pre-processor and post-processor pipelines
""" """
# Try to use Qwen processor if available
try:
from transformers import AutoProcessor
tokenizer_name = config.vlm_model_name
qwen_available = True
except ImportError:
tokenizer_name = "Qwen/Qwen2-VL-2B-Instruct" # Fallback
qwen_available = False
input_steps = [ input_steps = [
RenameObservationsProcessorStep(rename_map={}), RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
WallXTaskProcessor(), # Process task description WallXTaskProcessor(), # Process task description
TokenizerProcessorStep(
tokenizer_name=tokenizer_name,
padding="max_length",
padding_side="right",
max_length=config.tokenizer_max_length,
),
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
@@ -152,30 +133,3 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features return features
@ProcessorStepRegistry.register(name="wall_x_image_processor")
class WallXImageProcessor(ComplementaryDataProcessorStep):
    """
    Image processor for Wall-X using Qwen-VL vision processing.

    This handles image formatting according to Qwen-VL requirements.
    The step itself is a pass-through: both ``complementary_data`` and
    ``transform_features`` return their input unchanged.

    Attributes are documented here rather than inline:
        processor: The loaded ``AutoProcessor`` instance, or ``None`` when
            ``transformers`` could not be imported.
        available: ``True`` when ``processor`` was loaded, ``False`` otherwise.
    """

    def __init__(self):
        super().__init__()
        # Define the attribute up front so it exists on every code path;
        # previously it was left unset when the import failed, and any later
        # read of `self.processor` raised AttributeError.
        self.processor = None
        try:
            from transformers import AutoProcessor

            self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
            self.available = True
        except ImportError:
            self.available = False

    def complementary_data(self, complementary_data):
        """Return ``complementary_data`` unchanged."""
        # Image processing is handled by the VLM processor
        return complementary_data

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged; this step adds or removes no features."""
        return features
+29 -18
View File
@@ -33,10 +33,8 @@ from transformers import BatchFeature
from lerobot.policies.wall_x.constant import ( from lerobot.policies.wall_x.constant import (
CAMERA_NAME_MAPPING, CAMERA_NAME_MAPPING,
FREQUENCY_MAPPING,
KEY_MAPPINGS,
MULTIMODAL_DATASET_NAMES,
) )
from lerobot.utils.constants import OBS_IMAGES
@dataclass @dataclass
@@ -457,7 +455,7 @@ def get_wallx_normal_text(
action_chunk_size: int, action_chunk_size: int,
frame_idx: int, frame_idx: int,
priority_order: Optional[OrderedDict] = None, priority_order: Optional[OrderedDict] = None,
cam_mapping: Optional[Dict[str, str]] = None, img_keys: Optional[List[str]] = None,
generate_subtask_ratio: float = 0.0, generate_subtask_ratio: float = 0.0,
) -> Tuple[str, bool]: ) -> Tuple[str, bool]:
"""Construct complete multimodal prompt text for Wall-X model. """Construct complete multimodal prompt text for Wall-X model.
@@ -474,7 +472,7 @@ def get_wallx_normal_text(
action_chunk_size: Number of action tokens to generate action_chunk_size: Number of action tokens to generate
frame_idx: Current frame index frame_idx: Current frame index
priority_order: Priority order for instruction sampling priority_order: Priority order for instruction sampling
cam_mapping: Camera name mapping dictionary img_keys: List of image keys
generate_subtask_ratio: Probability of generating subtask instead of actions generate_subtask_ratio: Probability of generating subtask instead of actions
Returns: Returns:
@@ -497,10 +495,10 @@ def get_wallx_normal_text(
# User request with observation # User request with observation
user_request = f"{role_start_symbol}user\nObservation:" user_request = f"{role_start_symbol}user\nObservation:"
if cam_mapping: if img_keys:
for _, cam_name in cam_mapping.items(): img_keys = img_key_mapping(img_keys)
view_name = CAMERA_NAME_MAPPING.get(cam_name, cam_name) for key in img_keys:
user_request += f" {view_name}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}" user_request += f" {key}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}"
user_request += "\nInstruction:" user_request += "\nInstruction:"
# Get frame-specific instruction # Get frame-specific instruction
@@ -543,6 +541,27 @@ def get_wallx_normal_text(
complete_text = prologue + user_message + assistant_output complete_text = prologue + user_message + assistant_output
return complete_text, generate_subtask return complete_text, generate_subtask
def img_key_mapping(img_keys: List[str]) -> List[str]:
    """Translate dataset image keys into the camera names used in prompts.

    For each key, the ``OBS_IMAGES + "."`` substring is removed first. The
    remainder is then resolved in this order:

    1. If it is a key of ``CAMERA_NAME_MAPPING``, the mapped value is used.
    2. Otherwise, if it contains ``"view"``, underscores become spaces.
    3. Otherwise, ``" view"`` is appended to it as-is.

    Args:
        img_keys: List of image keys

    Returns:
        List of camera names, in the same order as ``img_keys``
    """
    obs_prefix = OBS_IMAGES + "."

    def _camera_name(raw_key: str) -> str:
        short_key = raw_key.replace(obs_prefix, "")
        if short_key in CAMERA_NAME_MAPPING:
            return CAMERA_NAME_MAPPING[short_key]
        if 'view' in short_key:
            return short_key.replace('_', ' ')
        return short_key + " view"

    return [_camera_name(raw_key) for raw_key in img_keys]
def get_action_tokens( def get_action_tokens(
normalized_actions: Union[torch.Tensor, List], action_tokenizer normalized_actions: Union[torch.Tensor, List], action_tokenizer
@@ -599,7 +618,6 @@ def replace_action_token(
text: List[str], text: List[str],
norm_action: Optional[torch.Tensor], norm_action: Optional[torch.Tensor],
action_tokenizer, action_tokenizer,
dataset_names: List[str],
dof_masks: Optional[torch.Tensor] = None, dof_masks: Optional[torch.Tensor] = None,
) -> List[str]: ) -> List[str]:
"""Replace action placeholders in text with actual action tokens. """Replace action placeholders in text with actual action tokens.
@@ -615,17 +633,10 @@ def replace_action_token(
List of text strings with action tokens replaced List of text strings with action tokens replaced
""" """
# Filter out multimodal dataset names # Filter out multimodal dataset names
dataset_names = [
name for name in dataset_names if name not in MULTIMODAL_DATASET_NAMES
]
# Get required action chunk sizes
required_chunk_sizes = [32 for name in dataset_names]
if action_tokenizer is not None and norm_action is not None: if action_tokenizer is not None and norm_action is not None:
# Extract actions based on chunk sizes and DOF masks # Extract actions based on chunk sizes and DOF masks
norm_action = [ norm_action = [
action[: required_chunk_sizes[i], dof_masks[i, 0].bool()] action[: 32, dof_masks[i, 0].bool()]
for i, action in enumerate(norm_action) for i, action in enumerate(norm_action)
] ]
+2
View File
@@ -0,0 +1,2 @@
# Wall-X policy tests
+126
View File
@@ -0,0 +1,126 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local Wall-X installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.wall_x import ( # noqa: E402
WallXConfig,
WallXPolicy,
make_wall_x_pre_post_processors, # noqa: E402
)
from lerobot.utils.random_utils import set_seed # noqa: E402
def test_policy_instantiation():
    """Smoke-test the Wall-X policy: build it, run one forward pass, then select an action.

    The test seeds the RNG, creates a ``WallXConfig`` with one state feature and
    one camera feature, builds the policy and its pre/post processors from dummy
    statistics, and pushes a single random batch through ``policy.forward`` and
    ``policy.select_action``. Each stage prints a status line; any exception is
    reported on stdout and then re-raised so the test fails.
    """
    from lerobot.configs.types import FeatureType, PolicyFeature

    # State and action vectors share the same width in this dummy setup.
    vector_dim = 7
    image_shape = (3, 224, 224)

    def unit_stats(*shape):
        # Zero mean / unit std statistics for a feature of the given shape.
        return {"mean": torch.zeros(*shape), "std": torch.ones(*shape)}

    # Seed first so that everything created afterwards is reproducible.
    set_seed(42)
    config = WallXConfig(device="cuda")

    # Declare what the policy consumes and what it produces.
    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(vector_dim,)),
        "observation.images.face_view": PolicyFeature(type=FeatureType.VISUAL, shape=image_shape),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(vector_dim,)),
    }

    # Dummy dataset statistics, one entry per feature declared above.
    dataset_stats = {
        "observation.state": unit_stats(vector_dim),
        "action": unit_stats(vector_dim),
        "observation.images.face_view": unit_stats(*image_shape),
    }

    # Build the policy before the processors and before drawing any random data.
    policy = WallXPolicy(config)
    preprocessor, postprocessor = make_wall_x_pre_post_processors(
        config=config,
        dataset_stats=dataset_stats,
    )

    # Random single-sample batch placed on the device recorded in the config.
    batch_size = 1
    device = config.device
    tensor_kwargs = {"dtype": torch.float32, "device": device}
    state = torch.randn(batch_size, vector_dim, **tensor_kwargs)
    action_chunk = torch.randn(batch_size, config.chunk_size, vector_dim, **tensor_kwargs)
    # rand (not randn) keeps pixel values in the [0, 1] range.
    image = torch.rand(batch_size, *image_shape, **tensor_kwargs)

    batch = preprocessor(
        {
            "observation.state": state,
            "action": action_chunk,
            "observation.images.face_view": image,
            "task": ["Pick up the object"] * batch_size,
        }
    )

    # Stage 1: forward pass, which returns a (loss, loss_dict) pair.
    try:
        _, loss_dict = policy.forward(batch)
        print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
    except Exception as e:
        print(f"Forward pass failed: {e}")
        raise

    # Stage 2: action selection with gradients disabled, then post-processing.
    try:
        with torch.no_grad():
            action = postprocessor(policy.select_action(batch))
        print(f"Action: {action}")
        print(f"Action prediction successful. Action shape: {action.shape}")
    except Exception as e:
        print(f"Action prediction failed: {e}")
        raise
def test_config_creation():
    """Test policy config creation through factory.

    Requests the ``"wall_x"`` policy type from ``make_policy_config`` and checks
    that the object handed back is a ``WallXConfig``. Exceptions raised by the
    factory are not caught here: they propagate with their full traceback, which
    both pytest and the ``__main__`` runner already report.
    """
    config = make_policy_config(policy_type="wall_x")

    # Printing the type cannot fail a test; pin the expected class explicitly so
    # a factory that returns the wrong config type is caught.
    assert isinstance(config, WallXConfig), (
        f"Expected WallXConfig from factory, got {type(config).__name__}"
    )

    print("Config created successfully through factory")
    print(f" Config type: {type(config).__name__}")
if __name__ == "__main__":
    # Script entry point: run the policy smoke test first, then the factory check.
    for check in (test_policy_instantiation, test_config_creation):
        check()