From d10d3ef2515ebbb40ec979e17050e57f80f74314 Mon Sep 17 00:00:00 2001 From: Geoffrey19 Date: Thu, 4 Dec 2025 15:38:04 +0800 Subject: [PATCH] reduce to least config and params & pass lerobot basic test --- .../policies/wall_x/configuration_wall_x.py | 447 ++-------------- src/lerobot/policies/wall_x/constant.py | 20 +- .../policies/wall_x/modeling_wall_x.py | 493 +++++------------- .../policies/wall_x/processor_wall_x.py | 46 -- src/lerobot/policies/wall_x/utils.py | 47 +- tests/policies/wall_x/__init__.py | 2 + tests/policies/wall_x/test_wallx.py | 126 +++++ 7 files changed, 349 insertions(+), 832 deletions(-) create mode 100644 tests/policies/wall_x/__init__.py create mode 100644 tests/policies/wall_x/test_wallx.py diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py index c0936f427..a4d67eb0f 100644 --- a/src/lerobot/policies/wall_x/configuration_wall_x.py +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -13,13 +13,11 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig -from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("wall_x") @@ -34,48 +32,15 @@ class WallXConfig(PreTrainedConfig): This config supports multi-modal learning with vision, language, and action data. """ - # ==================== Model and Paths Configuration ==================== - # Logging - log_name: str = "wall_x_training" - log_project: str = "vla_training" - model_type: str = "wall-oss" - - # Pretrained model paths - pretrained_wallx_path: str | None = None # Path to pretrained Wall-X model - save_path: str | None = None # Path to save checkpoints - processor_path: str | None = None # Path to processor (defaults to pretrained_wallx_path) - action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode) - - # Tokenizer settings - use_fast_tokenizer: bool = False # True: train FAST, False: train Flow - - # ==================== Profiling Configuration ==================== - profile: bool = False - profile_save_path: str | None = None - profile_wait_iters: int = 10 - profile_warmup_iters: int = 5 - profile_active_iters: int = 2 - - # ==================== Training Hyperparameters ==================== - num_warmup_steps: int = 100 - num_training_steps: int = 64000000 - learning_rate: float = 5e-5 - min_lr: float = 5e-5 - num_epoch: int = 100 - gradient_accumulation_steps: int = 32 - batch_size_per_gpu: int = 8 - padding_side: str = "left" - epoch_save_interval: int = 10 - - # Training optimization - fsdp2: bool = False - torch_compile: bool = False - # ==================== Input / Output Structure ==================== n_obs_steps: int = 1 chunk_size: int = 32 # action_horizon in wall-x n_action_steps: int = 32 + # Action dimension - wall-x uses 20 + max_action_dim: int = 20 + max_state_dim: int = 20 # For proprioception + normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.IDENTITY, @@ -84,101 +49,17 @@ class WallXConfig(PreTrainedConfig): } ) - # Action dimension - wall-x uses hardcoded 20 - max_action_dim: int = 20 - max_state_dim: int = 20 # For proprioception + # ==================== Action Prediction ==================== + # Pretrained model paths + pretrained_name_or_path: str = "x-square-robot/wall-oss-flow" - # Image preprocessing - resize_imgs_with_padding: tuple[int, int] | None = None # wall-x uses Qwen processor + # Action prediction mode: "diffusion" or "fast" + prediction_mode: str = "diffusion" - # Tokenizer - tokenizer_max_length: int = 256 + # Tokenizer settings + use_fast_tokenizer: bool = False # True: train FAST, False: train Flow + action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode) - # ==================== Model Architecture ==================== - vlm_model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct" - load_vlm_weights: bool = True - - # Vision config - vision_config: dict = field(default_factory=lambda: { - "depth": 32, - "hidden_size": 3584, - "hidden_act": "silu", - "intermediate_size": 3420, - "num_heads": 16, - "patch_size": 14, - "spatial_merge_size": 2, - "temporal_patch_size": 2, - "window_size": 112, - "out_hidden_size": 3584, - }) - - # Language model config - hidden_size: int = 3584 # 8192 for 7B model - intermediate_size: int = 18944 # 29568 for 7B model - num_hidden_layers: int = 36 # 80 for 7B model - num_attention_heads: int = 28 # 64 for 7B model - num_key_value_heads: int = 4 # 8 for 7B model - vocab_size: int = 152064 - - # ==================== Action Prediction ==================== - # Action prediction mode: "flow" or "fast" - prediction_mode: str = "flow" - - # Flow matching parameters - noise_scheduler: dict = field(default_factory=lambda: { - "beta_alpha": 1.5, # Beta distribution concentration1 - "beta_beta": 1.0, # Beta distribution concentration0 - "s": 0.999, # Scaling factor for time - }) - - # Decoding parameters - num_inference_timesteps: int = 10 # Number of ODE solver steps - ode_solver_method: str = "euler" # ODE solver method - - # ==================== Robot Configuration ==================== - # Degrees of freedom configuration - defines action space - dof_config: dict = field(default_factory=lambda: { - "left_ee_pos": 3, - "left_ee_rot": 3, - "left_gripper": 1, - "right_ee_pos": 3, - "right_ee_rot": 3, - "right_gripper": 1, - }) - - # Proprioception configuration (typically mirrors dof_config) - agent_pos_config: dict = field(default_factory=lambda: { - "left_ee_pos": 3, - "left_ee_rot": 3, - "left_gripper": 1, - "right_ee_pos": 3, - "right_ee_rot": 3, - "right_gripper": 1, - }) - - # Customized robot configuration - enable_customized_robot_config: bool = False - customized_robot_config: dict = field(default_factory=lambda: { - "name": "", - "customized_dof_config": {}, - "customized_agent_pos_config": {}, - }) - - # Normalization statistics path - norm_stats_path: str | None = None - - # ==================== MoE Configuration ==================== - num_experts: int = 4 - attention_moe: bool = False - mlp_moe: bool = False - - # ==================== Finetuning Settings ==================== - freeze_vision_encoder: bool = True - train_expert_only: bool = False # wall-x trains more components - train_action_head: bool = True - - # Cache - use_cache: bool = True # ==================== Optimizer Presets ==================== optimizer_lr: float = 2e-5 @@ -191,44 +72,6 @@ class WallXConfig(PreTrainedConfig): scheduler_decay_steps: int = 100000 scheduler_decay_lr: float = 1e-6 - # ==================== Dataset Configuration ==================== - # Dataset-specific normalization statistics - action_statistics: dict = field(default_factory=dict) - - # Data configuration - data_config: dict = field(default_factory=lambda: { - "use_lerobot": True, - "lerobot_config": { - "repo_id": "", - "root": None, - "episodes": None, - "image_transforms": None, - "delta_timestamps": None, - "tolerance_s": 1e-4, - "revision": None, - "force_cache_sync": False, - "download_videos": True, - "video_backend": None, - }, - "action_horizon": 32, - "train_test_split": 0.95, - "obs_action_keys": [], - "predict_action_keys": [], - "resolution": { - "face_view": 256, - "left_wrist_view": 256, - "right_wrist_view": 256, - "move1_view": 256, - "move2_view": 256, - "top_view": 256, - "wall_view": 256, - "multi_modal": 256, - }, - }) - - # ==================== Resume Configuration ==================== - resume_config: dict | None = field(default_factory=lambda: None) - def __post_init__(self): super().__post_init__() @@ -239,243 +82,55 @@ class WallXConfig(PreTrainedConfig): f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." ) - if self.prediction_mode not in ["flow", "fast"]: + if self.prediction_mode not in ["diffusion", "fast"]: raise ValueError( - f"prediction_mode must be 'flow' or 'fast', got {self.prediction_mode}" - ) - - # Validate dof_config total doesn't exceed max_action_dim - total_dof = sum(self.dof_config.values()) - if total_dof > self.max_action_dim: - raise ValueError( - f"Total DOF ({total_dof}) exceeds max_action_dim ({self.max_action_dim})" + f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}" ) # Sync prediction_mode with use_fast_tokenizer if self.use_fast_tokenizer: self.prediction_mode = "fast" else: - self.prediction_mode = "flow" + self.prediction_mode = "diffusion" - def get_train_config(self) -> dict: - """ - Extract the complete train_config dictionary matching the YAML training configuration format. + def validate_features(self) -> None: + """Validate and set up input/output features.""" + image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL] + if not image_features: + raise ValueError( + "Wall-X policy requires at least one visual input feature. " + "No features of type FeatureType.VISUAL found in input_features." + ) - This method constructs the full train_config from WallXConfig fields, suitable for - training scripts and Qwen2_5_VLMoEForAction.from_pretrained. - - Returns: - dict: Complete training configuration matching YAML structure. - """ - # Build customized_robot_config - if self.enable_customized_robot_config and self.customized_robot_config: - customized_robot_config = { - "name": self.customized_robot_config.get("name", ""), - "customized_dof_config": self.customized_robot_config.get( - "customized_dof_config", self.dof_config - ), - "customized_agent_pos_config": self.customized_robot_config.get( - "customized_agent_pos_config", self.agent_pos_config - ), - } + if "observation.state" not in self.input_features: + state_feature = PolicyFeature( + type=FeatureType.STATE, + shape=(self.max_state_dim,), # Padded to max_state_dim + ) + self.input_features["observation.state"] = state_feature else: - customized_robot_config = { - "name": self.data_config.get("lerobot_config", {}).get("repo_id", ""), - "customized_dof_config": self.dof_config, - "customized_agent_pos_config": self.agent_pos_config, - } + state_shape = self.input_features["observation.state"].shape + state_dim = state_shape[0] if state_shape else 0 + if state_dim > self.max_state_dim: + raise ValueError( + f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. " + f"Either reduce state dimension or increase max_state_dim in config." + ) - train_config = { - # Model and paths configuration - "log_name": self.log_name, - "log_project": self.log_project, - "model_type": self.model_type, - "pretrained_wallx_path": self.pretrained_wallx_path, - "save_path": self.save_path, - "use_fast_tokenizer": self.use_fast_tokenizer, - "action_tokenizer_path": self.action_tokenizer_path, - - # Profiling configuration - "profile": self.profile, - "profile_save_path": self.profile_save_path, - "profile_wait_iters": self.profile_wait_iters, - "profile_warmup_iters": self.profile_warmup_iters, - "profile_active_iters": self.profile_active_iters, - - # Training hyperparameters - "num_warmup_steps": self.num_warmup_steps, - "num_training_steps": self.num_training_steps, - "learning_rate": self.learning_rate, - "min_lr": self.min_lr, - "num_epoch": self.num_epoch, - "gradient_accumulation_steps": self.gradient_accumulation_steps, - "batch_size_per_gpu": self.batch_size_per_gpu, - "padding_side": self.padding_side, - "epoch_save_interval": self.epoch_save_interval, - - # Training optimization - "FSDP2": self.fsdp2, - "torch_compile": self.torch_compile, - - # Robot configuration - "dof_config": self.dof_config, - "agent_pos_config": self.agent_pos_config, - - # Normalization stats - "norm_stats_path": self.norm_stats_path, - - # Customized robot config - "enable_customized_robot_config": self.enable_customized_robot_config, - "customized_robot_config": customized_robot_config, - - # Resume configuration - "resume": self.resume_config, - - # Data configuration - "data": self.data_config, - } - - return train_config - - def get_dataload_config(self) -> dict: - """ - Extract data loading configuration from config. - - Returns: - dict: Data loading configuration for preprocessing. - """ - return { - "action_horizon": self.data_config.get("action_horizon", self.chunk_size), - "train_test_split": self.data_config.get("train_test_split", 0.95), - "split_seed": 42, - "predict_action_keys": self.data_config.get("predict_action_keys", []), - "obs_action_keys": self.data_config.get("obs_action_keys", []), - "resolution": self.data_config.get("resolution", {}), - "priority_order": None, - "max_length": self.tokenizer_max_length, - } - - def get_lerobot_config(self) -> dict: - """ - Extract LeRobot dataset configuration. - - Returns: - dict: LeRobot dataset configuration. - """ - return self.data_config.get("lerobot_config", {}) - - @classmethod - def from_yaml_dict(cls, yaml_dict: dict) -> "WallXConfig": - """ - Create a WallXConfig from a YAML configuration dictionary. - - Args: - yaml_dict: Dictionary loaded from YAML training config file. - - Returns: - WallXConfig instance with values from YAML. - """ - config_kwargs = {} - - # Model and paths - if "log_name" in yaml_dict: - config_kwargs["log_name"] = yaml_dict["log_name"] - if "log_project" in yaml_dict: - config_kwargs["log_project"] = yaml_dict["log_project"] - if "model_type" in yaml_dict: - config_kwargs["model_type"] = yaml_dict["model_type"] - if "pretrained_wallx_path" in yaml_dict: - config_kwargs["pretrained_wallx_path"] = yaml_dict["pretrained_wallx_path"] - if "save_path" in yaml_dict: - config_kwargs["save_path"] = yaml_dict["save_path"] - if "use_fast_tokenizer" in yaml_dict: - config_kwargs["use_fast_tokenizer"] = yaml_dict["use_fast_tokenizer"] - if "action_tokenizer_path" in yaml_dict: - config_kwargs["action_tokenizer_path"] = yaml_dict["action_tokenizer_path"] - - # Profiling - if "profile" in yaml_dict: - config_kwargs["profile"] = yaml_dict["profile"] - if "profile_save_path" in yaml_dict: - config_kwargs["profile_save_path"] = yaml_dict["profile_save_path"] - if "profile_wait_iters" in yaml_dict: - config_kwargs["profile_wait_iters"] = yaml_dict["profile_wait_iters"] - if "profile_warmup_iters" in yaml_dict: - config_kwargs["profile_warmup_iters"] = yaml_dict["profile_warmup_iters"] - if "profile_active_iters" in yaml_dict: - config_kwargs["profile_active_iters"] = yaml_dict["profile_active_iters"] - - # Training hyperparameters - if "num_warmup_steps" in yaml_dict: - config_kwargs["num_warmup_steps"] = yaml_dict["num_warmup_steps"] - config_kwargs["scheduler_warmup_steps"] = yaml_dict["num_warmup_steps"] - if "num_training_steps" in yaml_dict: - config_kwargs["num_training_steps"] = yaml_dict["num_training_steps"] - config_kwargs["scheduler_decay_steps"] = yaml_dict["num_training_steps"] - if "learning_rate" in yaml_dict: - config_kwargs["learning_rate"] = yaml_dict["learning_rate"] - config_kwargs["optimizer_lr"] = yaml_dict["learning_rate"] - if "min_lr" in yaml_dict: - config_kwargs["min_lr"] = yaml_dict["min_lr"] - config_kwargs["scheduler_decay_lr"] = yaml_dict["min_lr"] - if "num_epoch" in yaml_dict: - config_kwargs["num_epoch"] = yaml_dict["num_epoch"] - if "gradient_accumulation_steps" in yaml_dict: - config_kwargs["gradient_accumulation_steps"] = yaml_dict["gradient_accumulation_steps"] - if "batch_size_per_gpu" in yaml_dict: - config_kwargs["batch_size_per_gpu"] = yaml_dict["batch_size_per_gpu"] - if "padding_side" in yaml_dict: - config_kwargs["padding_side"] = yaml_dict["padding_side"] - if "epoch_save_interval" in yaml_dict: - config_kwargs["epoch_save_interval"] = yaml_dict["epoch_save_interval"] - - # Training optimization - if "FSDP2" in yaml_dict: - config_kwargs["fsdp2"] = yaml_dict["FSDP2"] - if "torch_compile" in yaml_dict: - config_kwargs["torch_compile"] = yaml_dict["torch_compile"] - - # Robot configuration - if "dof_config" in yaml_dict: - config_kwargs["dof_config"] = yaml_dict["dof_config"] - if "agent_pos_config" in yaml_dict: - config_kwargs["agent_pos_config"] = yaml_dict["agent_pos_config"] - - # Normalization stats - if "norm_stats_path" in yaml_dict: - config_kwargs["norm_stats_path"] = yaml_dict["norm_stats_path"] - - # Customized robot config - if "enable_customized_robot_config" in yaml_dict: - config_kwargs["enable_customized_robot_config"] = yaml_dict["enable_customized_robot_config"] - if "customized_robot_config" in yaml_dict: - config_kwargs["customized_robot_config"] = yaml_dict["customized_robot_config"] - - # Resume config - if "resume" in yaml_dict: - config_kwargs["resume_config"] = yaml_dict["resume"] - - # Data configuration - if "data" in yaml_dict: - data = yaml_dict["data"] - data_config = { - "use_lerobot": data.get("use_lerobot", True), - "action_horizon": data.get("action_horizon", 32), - "train_test_split": data.get("train_test_split", 0.95), - "obs_action_keys": data.get("obs_action_keys", []), - "predict_action_keys": data.get("predict_action_keys", []), - "resolution": data.get("resolution", {}), - } - if "lerobot_config" in data: - data_config["lerobot_config"] = data["lerobot_config"] - config_kwargs["data_config"] = data_config - - # Set chunk_size from action_horizon - if "action_horizon" in data: - config_kwargs["chunk_size"] = data["action_horizon"] - config_kwargs["n_action_steps"] = data["action_horizon"] - - return cls(**config_kwargs) + if "action" not in self.output_features: + action_feature = PolicyFeature( + type=FeatureType.ACTION, + shape=(self.max_action_dim,), # Padded to max_action_dim + ) + self.output_features["action"] = action_feature + else: + action_shape = self.output_features["action"].shape + action_dim = action_shape[0] if action_shape else 0 + if action_dim > self.max_action_dim: + raise ValueError( + f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. " + f"Either reduce action dimension or increase max_action_dim in config." + ) def get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig( @@ -496,7 +151,7 @@ class WallXConfig(PreTrainedConfig): @property def observation_delta_indices(self) -> list: - return [0] + return None @property def action_delta_indices(self) -> list: diff --git a/src/lerobot/policies/wall_x/constant.py b/src/lerobot/policies/wall_x/constant.py index 872302ad6..a894cef06 100644 --- a/src/lerobot/policies/wall_x/constant.py +++ b/src/lerobot/policies/wall_x/constant.py @@ -16,15 +16,9 @@ """ Wall-X Constants and Configuration Data. - -Contains dataset names, key mappings, frequency mappings, and action statistics -for cross-embodiment robotic control. """ -from pathlib import Path - -# Add wall-x repo to path if available -WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x") +from lerobot.utils.constants import OBS_STATE, OBS_IMAGES, ACTION CAMERA_NAME_MAPPING = { "face_view": "front view", @@ -35,3 +29,15 @@ CAMERA_NAME_MAPPING = { "wall_view": "wall view", "top_view": "top view", } + +RESOLUTION = 256 + +# Parameters for preprocessing +MAX_PIXELS = 16384 * 28 * 28 +MIN_PIXELS = 4 * 28 * 28 +IMAGE_FACTOR = 28 +PRIORITY_ORDER = None +GENERATE_SUBTASK_RATIO = 0.0 +MODEL_TYPE = "qwen2_5" + +TOKENIZER_MAX_LENGTH = 768 \ No newline at end of file diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index bebba4a27..dbd147872 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -38,6 +38,7 @@ import builtins import glob import math import os +from os import PathLike import sys from collections import deque from pathlib import Path @@ -57,16 +58,8 @@ from torchdiffeq import odeint from transformers import AutoConfig, AutoProcessor from transformers.activations import ACT2FN from transformers.cache_utils import ( - Cache, - DynamicCache, - SlidingWindowCache, StaticCache, ) -from transformers.generation import GenerationMixin -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.modeling_utils import PreTrainedModel from transformers.utils import is_torchdynamo_compiling, logging from transformers import AutoProcessor, BatchFeature from qwen_vl_utils.vision_process import smart_resize @@ -80,29 +73,18 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN from lerobot.policies.wall_x.utils import * from lerobot.policies.wall_x.constant import * from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig -from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2RMSNorm, -) from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLRotaryEmbedding, - Qwen2_5_VLPreTrainedModel, Qwen2_5_VLForConditionalGeneration, ) from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import ( Qwen2_5_VisionTransformerPretrainedModel, - Qwen2_5_VLDecoderLayer_with_MoE, Qwen2_5_VLACausalLMOutputWithPast, Qwen2_5_VLMoEModel, ) logger = logging.get_logger(__name__) -# Add wall-x repo to path if available -WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x") -if WALL_X_PATH.exists(): - sys.path.insert(0, str(WALL_X_PATH)) - class SinusoidalPosEmb(nn.Module): """Sinusoidal positional embedding for diffusion timesteps.""" @@ -129,7 +111,7 @@ class ActionHead(nn.Module): for action sequence prediction. """ - def __init__(self, config: WallXConfig): + def __init__(self, config): super().__init__() self.config = config @@ -138,10 +120,9 @@ class ActionHead(nn.Module): self.hidden_size = config.hidden_size # Beta distribution for noise scheduling - noise_config = config.noise_scheduler - self.beta_alpha = noise_config.get("beta_alpha", 1.5) - self.beta_beta = noise_config.get("beta_beta", 1.0) - self.s = noise_config.get("s", 0.999) + self.beta_alpha = 1.5 + self.beta_beta = 1.0 + self.s = 0.999 # Sinusoidal timestep embedding self.time_embed = SinusoidalPosEmb(config.hidden_size) @@ -277,12 +258,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): @classmethod def from_pretrained( cls, - pretrained_model_path, - train_config, - config_path=None, - processor_path=None, + pretrained_name_or_path, + config=None, action_tokenizer_path=None, - **kwargs, + cache_dir: str | PathLike | None = None, + force_download: bool = False, + local_files_only: bool = False, + token: str | bool | None = None, + revision: str = "main", + strict: bool = False, + **kwargs: Any ): """ Load model from pretrained model path. @@ -290,17 +275,24 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): Args: pretrained_model_path (str): Model directory path containing model.safetensors file config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path - processor_path (str, optional): Processor path, if None will load from default config action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config **kwargs: Additional arguments Returns: Qwen2_5_VLMoEForAction: Loaded model instance """ - # Load model components from pretrained path - config_path = os.path.join(pretrained_model_path, "config.json") - config = cls.config_class.from_pretrained(config_path) - processor = AutoProcessor.from_pretrained(pretrained_model_path, use_fast=True) + if config is None: + config = cls.config_class.from_pretrained( + pretrained_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + strict=strict, + **kwargs, + ) + processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True) if action_tokenizer_path is not None: processor.action_processor = AutoProcessor.from_pretrained( action_tokenizer_path, trust_remote_code=True @@ -312,22 +304,41 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): # Resize token embeddings to match processor tokenizer vocabulary size model.resize_token_embeddings(len(processor.tokenizer)) - # Load model state dict from safetensors file - safetensor_files = glob.glob( - os.path.join(pretrained_model_path, "*.safetensors") - ) + # Try to load the model.safetensors file + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Try safetensors first + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + sd = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + state_dict = {} - for file in safetensor_files: - sd = load_file(file, device="cpu") - # filter normalizer statistic params - del_keys = [] - for key in sd.keys(): - if "action_preprocessor.normalizer" in key: - print(f"filter load model weight {key}") - del_keys.append(key) - for key in del_keys: - del sd[key] - state_dict.update(sd) + # filter normalizer statistic params + del_keys = [] + for key in sd.keys(): + if "action_preprocessor.normalizer" in key: + del_keys.append(key) + for key in del_keys: + del sd[key] + state_dict.update(sd) model.load_state_dict(state_dict, strict=False) @@ -730,7 +741,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, - dataset_names: Optional[str] = None, dof_mask: Optional[torch.FloatTensor] = None, agent_pos_mask: Optional[torch.FloatTensor] = None, **kwargs, @@ -763,7 +773,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): rope_deltas (torch.LongTensor, optional): RoPE position deltas cache_position (torch.LongTensor, optional): Cache position indices second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid - dataset_names (str, optional): Names of datasets in the current batch dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data **kwargs: Additional keyword arguments @@ -872,7 +881,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): ) proprioception = self.action_preprocessor.proprioception_proj( proprioception, - dataset_names, agent_pos_mask, use_history=proprioception.shape[1] > 1, ) @@ -909,7 +917,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): ) dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) noisy_action_emb, flow = self.action_preprocessor( - action_chunk, dataset_names, dof_mask + action_chunk, dof_mask ) mask = input_ids == self.action_token_id_set["action_token_id"] mask_unsqueezed = mask.unsqueeze(-1) @@ -953,18 +961,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): # Compute losses if labels are provided if labels is not None: loss = 0 - action_accuracy = 0 - unique_datasets_name = list(set(dataset_names)) - - # Initialize per-dataset loss tracking dictionaries - channel_loss_dict = { - dataset_name: torch.tensor(0.0, device=logits.device) - for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES - } - channel_loss_count_dict = { - dataset_name: torch.tensor(0, device=logits.device) - for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES - } # Compute standard cross-entropy loss for language modeling shift_logits = logits[..., :-1, :].contiguous() @@ -982,22 +978,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): else torch.tensor(0.0, device=shift_logits.device) ) - # Compute per-dataset channel losses - _cross_entropy_loss = _cross_entropy_loss.view(batch_size, seq_length - 1) - non_ignored_mask = non_ignored_mask.view(batch_size, seq_length - 1) - for dataset_name_i in unique_datasets_name: - dataset_mask = torch.tensor( - [name == dataset_name_i for name in dataset_names], - device=logits.device, - ) - combined_mask = dataset_mask.unsqueeze(1) & non_ignored_mask - channel_loss_dict[dataset_name_i] = ( - _cross_entropy_loss[combined_mask].sum() - if combined_mask.any() - else torch.tensor(0.0, device=shift_logits.device) - ) - channel_loss_count_dict[dataset_name_i] += combined_mask.sum() - # Add cross-entropy loss to total loss if valid if not torch.isnan(cross_entropy_loss): loss += cross_entropy_loss @@ -1005,20 +985,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): with torch.no_grad(): cross_entropy_loss.detach() - # Compute action token prediction accuracy - shift_logits = logits[..., :-1, :].contiguous() - action_preds = shift_logits.argmax(dim=-1) - shift_labels = labels[..., 1:].contiguous() - if self.use_fast_tokenizer: - action_mask = ( - shift_labels > self.action_token_id_set["fast_action_token_list"][0] - ) - correct_preds = (action_preds == shift_labels) & action_mask - action_accuracy = ( - correct_preds.sum().float() / action_mask.sum().float() - ) - channel_loss_dict["action_accuracy"] = action_accuracy - if action_chunk is not None: action_mask = input_ids == self.action_token_id_set["action_token_id"] if action_mask.any(): @@ -1101,7 +1067,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, num_inference_timesteps: Optional[int] = 10, - dataset_names: Optional[str] = None, dof_mask: Optional[torch.FloatTensor] = None, agent_pos_mask: Optional[torch.FloatTensor] = None, re_generate: bool = False, @@ -1140,7 +1105,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): cache_position (torch.LongTensor, optional): Cache position indices second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid num_inference_timesteps (int, optional): Number of diffusion inference steps - dataset_names (str, optional): Dataset names for normalization dof_mask (torch.FloatTensor, optional): Degrees of freedom mask agent_pos_mask (torch.FloatTensor, optional): Agent position mask re_generate (bool, optional): Whether to use sampling for regeneration @@ -1239,7 +1203,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): ) proprio_embed = self.action_preprocessor.proprioception_proj( proprioception, - dataset_names, agent_pos_mask, use_history=proprioception.shape[1] > 1, ) @@ -1339,7 +1302,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): "dof_mask": dof_mask, "agent_pos_mask": agent_pos_mask, "proprioception": proprioception, - "dataset_names": dataset_names, } # Generate output tokens @@ -1545,7 +1507,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): image_grid_thw=None, video_grid_thw=None, second_per_grid_ts=None, - dataset_names=None, proprioception=None, dof_mask=None, agent_pos_mask=None, @@ -1572,7 +1533,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): image_grid_thw: Image grid dimensions video_grid_thw: Video grid dimensions second_per_grid_ts: Time interval per temporal grid - dataset_names: Dataset names for processing proprioception: Proprioceptive sensor data dof_mask: Degrees of freedom mask agent_pos_mask: Agent position mask @@ -1677,7 +1637,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): "cache_position": cache_position, "second_per_grid_ts": second_per_grid_ts, "proprioception": proprioception, - "dataset_names": dataset_names, "dof_mask": dof_mask, "agent_pos_mask": agent_pos_mask, } @@ -1866,150 +1825,14 @@ class WallXPolicy(PreTrainedPolicy): config.validate_features() self.config = config - # Initialize VLM wrapper - self.model = Qwen2_5_VLMoEForAction(config) + # Initialize the wall-x model + self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path) + self.model.to(config.device) + # Convert to bfloat16 for Flash Attention compatibility + self.model.to_bfloat16_for_selected_params() self.reset() - @classmethod - def from_pretrained( - cls: builtins.type[T], - pretrained_name_or_path: str | Path, - *, - config: PreTrainedConfig | None = None, - force_download: bool = False, - resume_download: bool | None = None, - proxies: dict | None = None, - token: str | bool | None = None, - cache_dir: str | Path | None = None, - local_files_only: bool = False, - revision: str | None = None, - strict: bool = False, - **kwargs, - ) -> T: - """ - Load WallXPolicy from a pretrained model path. - - Args: - pretrained_name_or_path: Path to pretrained model or model identifier - config: Optional configuration object - force_download: Force download even if cached - resume_download: Resume interrupted download - proxies: Proxy configuration - token: Authentication token - cache_dir: Cache directory path - local_files_only: Only use local files - revision: Model revision - strict: Strict loading of state dict - **kwargs: Additional arguments - - Returns: - WallXPolicy: Loaded policy instance - """ - print( - "Loading Wall-X model for cross-embodiment robotic control.\n" - "This implementation integrates Qwen2.5-VL with flow matching for action prediction." - ) - - if pretrained_name_or_path is None: - raise ValueError("pretrained_name_or_path is required") - - # Use provided config if available, otherwise load from pretrained path - if config is None: - config = PreTrainedConfig.from_pretrained( - pretrained_name_or_path=pretrained_name_or_path, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - **kwargs, - ) - - # Initialize model without loading weights - model = cls(config, **kwargs) - - # Load and remap the state dict - try: - print(f"Loading model from: {pretrained_name_or_path}") - try: - from transformers.utils import cached_file - - # Try safetensors first - resolved_file = cached_file( - pretrained_name_or_path, - "model.safetensors", - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - revision=revision, - local_files_only=local_files_only, - ) - original_state_dict = load_file(resolved_file) - print("✓ Loaded state dict from model.safetensors") - except Exception: - print(f"Could not load state dict: {e}") - print("Returning model without loading pretrained weights") - return model - - # Filter out normalizer statistics if present - filtered_state_dict = {} - for key, value in original_state_dict.items(): - if "action_preprocessor.normalizer" not in key: - filtered_state_dict[key] = value - else: - print(f"Filtered key: {key}") - - # Add "model." prefix for keys that don't have it - remapped_state_dict = {} - remap_count = 0 - - for key, value in filtered_state_dict.items(): - if not key.startswith("model."): - new_key = f"model.{key}" - remapped_state_dict[new_key] = value - remap_count += 1 - else: - remapped_state_dict[key] = value - - if remap_count > 0: - print(f"Remapped {remap_count} state dict keys") - - # Load the remapped state dict into the model - missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) - - if missing_keys: - print(f"Missing keys when loading state dict: {len(missing_keys)} keys") - if len(missing_keys) <= 5: - for key in missing_keys: - print(f" - {key}") - else: - for key in missing_keys[:5]: - print(f" - {key}") - print(f" ... and {len(missing_keys) - 5} more") - - if unexpected_keys: - print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") - if len(unexpected_keys) <= 5: - for key in unexpected_keys: - print(f" - {key}") - else: - for key in unexpected_keys[:5]: - print(f" - {key}") - print(f" ... and {len(unexpected_keys) - 5} more") - - if not missing_keys and not unexpected_keys: - print("All keys loaded successfully!") - - except Exception as e: - print(f"Warning: Could not load state dict: {e}") - - return model - def reset(self): """Reset action queue.""" self._queues = { @@ -2022,88 +1845,50 @@ class WallXPolicy(PreTrainedPolicy): def preprocess_inputs( self, - batch: List[Dict[str, Any]], - config: Dict[str, Any], - dataload_config: Dict[str, Any], - lerobot_config: Dict[str, Any], - processor: Any, - action_tokenizer: Optional[Any] = None, - camera_keys: Optional[List[str]] = None, + batch: Dict[str, Any], ) -> BatchFeature: """ Convert a batch of LeRobot dataset items to Wall-X model input format. - This is the batch version of convert_lerobot_to_wallx_format, processing multiple - samples together for efficient batched inference or training. + This processes a batched dictionary where tensors have batch dimension first. Args: - batch: List of items from LeRobot dataset - config: Model configuration dict - dataload_config: Data loading configuration - norm_stats: Normalization statistics - lerobot_config: LeRobot config containing 'repo_id' - processor: Hugging Face processor for tokenization - action_tokenizer: Optional action tokenizer - camera_keys: List of camera keys in the dataset + batch: Dictionary with batched tensors: + - "observation.state": (batch_size, state_dim) or (batch_size, n_obs_steps, state_dim) + - "action": (batch_size, chunk_size, action_dim) + - "observation.images.": (batch_size, C, H, W) + - "task": List[str] of length batch_size Returns: BatchFeature containing batched model inputs - - Example: - >>> batch = [dataset[i] for i in range(8)] - >>> result = preprocess_inputs( - ... batch=batch, - ... config=config, - ... dataload_config=dataload_config, - ... norm_stats=norm_stats, - ... lerobot_config={'repo_id': 'lerobot/aloha_mobile_cabinet'}, - ... processor=processor, - ... ) """ - repo_id = lerobot_config["repo_id"] - use_fast_tokenizer = config.get("use_fast_tokenizer", False) + use_fast_tokenizer = self.config.use_fast_tokenizer - # Get key mappings - cam_key_mapping = KEY_MAPPINGS[repo_id]["camera"] - state_key = KEY_MAPPINGS[repo_id]["state"] - action_key = KEY_MAPPINGS[repo_id]["action"] - - # Build data config - data_config = X2RDataProcessingConfig().update( - train_test_split=dataload_config.get("train_test_split", 0.95), - split_seed=dataload_config.get("split_seed", 42), - predict_action_keys=dataload_config.get("predict_action_keys", []), - obs_action_keys=dataload_config.get("obs_action_keys", []), - resolution=dataload_config.get("resolution", None), - priority_order=dataload_config.get("priority_order", None), - ) - - if camera_keys is None: - camera_keys = list(cam_key_mapping.keys()) + # Get batch size from state tensor + batch_size = batch[OBS_STATE].shape[0] # ==================== PROCESS ALL SAMPLES ==================== all_image_inputs = [] all_texts = [] - all_agent_pos = [] - all_actions = [] - all_frame_indices = [] - for data in batch: + # Find image keys in batch + img_keys = [key for key in self.config.image_features if key in batch] + + for i in range(batch_size): # Vision preprocessing per sample processed_frames = [] orig_height, orig_width = None, None resized_height, resized_width = None, None - for key in camera_keys: - current_obs = data[key].clone() + for key in img_keys: + current_obs = batch[key][i].clone() # (C, H, W) if current_obs.dim() == 3: - current_obs = current_obs.permute(1, 2, 0) + current_obs = current_obs.permute(1, 2, 0) # (H, W, C) img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy()) orig_width, orig_height = img_pil.size - cam_name = cam_key_mapping.get(key, key) - target_size = data_config.resolution.get(cam_name, -1) + target_size = RESOLUTION if target_size != -1: if orig_width > orig_height: new_width = target_size @@ -2117,9 +1902,9 @@ class WallXPolicy(PreTrainedPolicy): resized_height, resized_width = smart_resize( current_height, current_width, - factor=data_config.image_factor, - min_pixels=data_config.min_pixels, - max_pixels=data_config.max_pixels, + factor=IMAGE_FACTOR, + min_pixels=MIN_PIXELS, + max_pixels=MAX_PIXELS, ) resized_img = img_pil.resize((resized_width, resized_height)) processed_frames.append(resized_img) @@ -2127,33 +1912,30 @@ class WallXPolicy(PreTrainedPolicy): all_image_inputs.append(processed_frames) # Text preprocessing - frame_index = data["frame_index"] - instruction_info = {"instruction": data["task"]} + task_text = batch["task"][i] if isinstance(batch["task"], list) else batch["task"] + instruction_info = {"instruction": task_text} + frame_index = batch["frame_index"][i] if "frame_index" in batch else 0 complete_text, _ = get_wallx_normal_text( instruction_info, - dataload_config.get("action_horizon", 33) - 1, + self.config.chunk_size, frame_index, - data_config.priority_order, - cam_key_mapping, - generate_subtask_ratio=data_config.generate_subtask_ratio, + PRIORITY_ORDER, + img_keys, + generate_subtask_ratio=GENERATE_SUBTASK_RATIO, ) text = process_grounding_points( complete_text, orig_height, orig_width, resized_height, resized_width, - data_config.model_type + MODEL_TYPE ) all_texts.append(text) - - # Collect raw values - all_agent_pos.append(data[state_key]) - all_actions.append(data[action_key]) - all_frame_indices.append(frame_index) + - # Stack agent_pos - agent_pos = torch.stack(all_agent_pos) + # ==================== PROCESS AGENT POS ==================== + agent_pos = batch[OBS_STATE] # (batch_size, state_dim) if agent_pos.dim() == 2: - agent_pos = agent_pos.unsqueeze(1) + agent_pos = agent_pos.unsqueeze(1) # (batch_size, 1, state_dim) agent_pos_mask = (~torch.isnan(agent_pos)).float() agent_pos = agent_pos.nan_to_num(nan=0.0) @@ -2161,15 +1943,15 @@ class WallXPolicy(PreTrainedPolicy): pad_size = 20 - agent_pos.shape[-1] agent_pos = torch.cat([ agent_pos, - torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size) + torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device) ], dim=-1) agent_pos_mask = torch.cat([ agent_pos_mask, - torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size) + torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size, device=agent_pos_mask.device) ], dim=-1) - # Stack actions - action = torch.stack(all_actions) + # ==================== PROCESS ACTIONS ==================== + action = batch[ACTION] # (batch_size, chunk_size, action_dim) if action.dim() == 2: action = action.unsqueeze(1) dof_mask = (~torch.isnan(action)).float() @@ -2179,36 +1961,35 @@ class WallXPolicy(PreTrainedPolicy): pad_size = 20 - action.shape[-1] action = torch.cat([ action, - torch.zeros(action.shape[0], action.shape[1], pad_size) + torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device) ], dim=-1) dof_mask = torch.cat([ dof_mask, - torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size) + torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device) ], dim=-1) # ==================== ACTION TOKEN REPLACEMENT ==================== all_texts = replace_action_token( all_texts, action, - action_tokenizer if use_fast_tokenizer else None, - [repo_id] * len(batch), + self.model.action_tokenizer if use_fast_tokenizer else None, dof_mask, ) # ==================== TOKENIZATION ==================== inputs = preprocesser_call( - processor=processor, + processor=self.model.processor, text=all_texts, images=all_image_inputs, videos=None, padding=True, truncation=True, return_tensors="pt", - max_length=dataload_config.get("max_length", 768), + max_length=TOKENIZER_MAX_LENGTH, ) # ==================== ADDITIONAL INPUTS ==================== - action_token_id = processor.tokenizer.convert_tokens_to_ids("<|action|>") + action_token_id = self.model.processor.tokenizer.convert_tokens_to_ids("<|action|>") moe_token_types = inputs.input_ids == action_token_id inputs["proprioception"] = agent_pos @@ -2216,11 +1997,13 @@ class WallXPolicy(PreTrainedPolicy): inputs["action_chunk"] = action inputs["dof_mask"] = dof_mask inputs["moe_token_types"] = moe_token_types - inputs["dataset_names"] = [repo_id] * len(batch) - inputs["frame_index"] = torch.stack([ - torch.tensor(fi) if not isinstance(fi, torch.Tensor) else fi - for fi in all_frame_indices - ]) + inputs["frame_index"] = batch["frame_index"] if "frame_index" in batch else torch.zeros(batch_size, device=batch[OBS_STATE].device) + + # Move all tensors to the correct device + device = self.config.device + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + inputs[key] = value.to(device) return inputs @@ -2232,37 +2015,19 @@ class WallXPolicy(PreTrainedPolicy): batch: Dictionary containing preprocessed inputs from preprocess_inputs() Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw, proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types, - dataset_names, etc. + etc. Returns: tuple: (loss, loss_dict) """ batch = self.preprocess_inputs( batch, - self.config, - self.dataload_config, - self.lerobot_config, - self.processor, - self.action_tokenizer, - self.camera_keys ) # Call the underlying model's forward with mode="train" outputs = self.model( - mode="train", - input_ids=batch.get("input_ids"), - attention_mask=batch.get("attention_mask"), - pixel_values=batch.get("pixel_values"), - image_grid_thw=batch.get("image_grid_thw"), - pixel_values_videos=batch.get("pixel_values_videos"), - video_grid_thw=batch.get("video_grid_thw"), - proprioception=batch.get("proprioception"), - agent_pos_mask=batch.get("agent_pos_mask"), - action_chunk=batch.get("action_chunk"), - dof_mask=batch.get("dof_mask"), - moe_token_types=batch.get("moe_token_types"), - dataset_names=batch.get("dataset_names"), - labels=batch.get("labels", batch.get("input_ids")), # Use input_ids as labels if not provided + **batch, + mode="train" ) # Extract losses from output @@ -2292,16 +2057,10 @@ class WallXPolicy(PreTrainedPolicy): batch = self.preprocess_inputs( batch, - self.config, - self.dataload_config, - self.lerobot_config, - self.processor, - self.action_tokenizer, - self.camera_keys ) if self.config.prediction_mode == "diffusion": - actions = self.model( + output = self.model( **batch, action_dim=self.config.max_action_dim, pred_horizon=self.config.chunk_size, @@ -2309,9 +2068,9 @@ class WallXPolicy(PreTrainedPolicy): predict_mode="diffusion" ) elif self.config.prediction_mode == "fast": - actions = self.model( + output = self.model( **batch, - action_dim=self.config.action_feature.shape[0], + action_dim=self.config.output_features["action"].shape[0], pred_horizon=self.config.chunk_size, mode="predict", predict_mode="fast" @@ -2319,8 +2078,12 @@ class WallXPolicy(PreTrainedPolicy): else: raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented") - # Unpad actions - actions = actions[:, :, :self.config.action_feature.shape[0]] + # Extract action tensor from output dictionary + actions = output["predict_action"] + + # Unpad actions to actual action dimension + action_dim = self.config.output_features["action"].shape[0] + actions = actions[:, :, :action_dim] return actions diff --git a/src/lerobot/policies/wall_x/processor_wall_x.py b/src/lerobot/policies/wall_x/processor_wall_x.py index 8becd2013..d8ad402ed 100644 --- a/src/lerobot/policies/wall_x/processor_wall_x.py +++ b/src/lerobot/policies/wall_x/processor_wall_x.py @@ -29,13 +29,10 @@ from lerobot.processor import ( PolicyProcessorPipeline, ProcessorStepRegistry, RenameObservationsProcessorStep, - TokenizerProcessorStep, UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME - - def make_wall_x_pre_post_processors( config: WallXConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, @@ -49,7 +46,6 @@ def make_wall_x_pre_post_processors( The pre-processing pipeline prepares input data for the model by: 1. Renaming features to match pretrained configurations 2. Adding a batch dimension - 3. Tokenizing language task descriptions 4. Normalizing input and output features based on dataset statistics 5. Moving all data to the specified device @@ -65,25 +61,10 @@ def make_wall_x_pre_post_processors( A tuple containing the configured pre-processor and post-processor pipelines """ - # Try to use Qwen processor if available - try: - from transformers import AutoProcessor - tokenizer_name = config.vlm_model_name - qwen_available = True - except ImportError: - tokenizer_name = "Qwen/Qwen2-VL-2B-Instruct" # Fallback - qwen_available = False - input_steps = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), WallXTaskProcessor(), # Process task description - TokenizerProcessorStep( - tokenizer_name=tokenizer_name, - padding="max_length", - padding_side="right", - max_length=config.tokenizer_max_length, - ), NormalizerProcessorStep( features={**config.input_features, **config.output_features}, norm_map=config.normalization_mapping, @@ -152,30 +133,3 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep): self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: return features - - -@ProcessorStepRegistry.register(name="wall_x_image_processor") -class WallXImageProcessor(ComplementaryDataProcessorStep): - """ - Image processor for Wall-X using Qwen-VL vision processing. - - This handles image formatting according to Qwen-VL requirements. - """ - - def __init__(self): - super().__init__() - try: - from transformers import AutoProcessor - self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") - self.available = True - except ImportError: - self.available = False - - def complementary_data(self, complementary_data): - # Image processing is handled by the VLM processor - return complementary_data - - def transform_features( - self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] - ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: - return features diff --git a/src/lerobot/policies/wall_x/utils.py b/src/lerobot/policies/wall_x/utils.py index 3a908317d..10d2ab6c1 100644 --- a/src/lerobot/policies/wall_x/utils.py +++ b/src/lerobot/policies/wall_x/utils.py @@ -33,10 +33,8 @@ from transformers import BatchFeature from lerobot.policies.wall_x.constant import ( CAMERA_NAME_MAPPING, - FREQUENCY_MAPPING, - KEY_MAPPINGS, - MULTIMODAL_DATASET_NAMES, ) +from lerobot.utils.constants import OBS_IMAGES @dataclass @@ -457,7 +455,7 @@ def get_wallx_normal_text( action_chunk_size: int, frame_idx: int, priority_order: Optional[OrderedDict] = None, - cam_mapping: Optional[Dict[str, str]] = None, + img_keys: Optional[List[str]] = None, generate_subtask_ratio: float = 0.0, ) -> Tuple[str, bool]: """Construct complete multimodal prompt text for Wall-X model. @@ -474,7 +472,7 @@ def get_wallx_normal_text( action_chunk_size: Number of action tokens to generate frame_idx: Current frame index priority_order: Priority order for instruction sampling - cam_mapping: Camera name mapping dictionary + img_keys: List of image keys generate_subtask_ratio: Probability of generating subtask instead of actions Returns: @@ -497,10 +495,10 @@ def get_wallx_normal_text( # User request with observation user_request = f"{role_start_symbol}user\nObservation:" - if cam_mapping: - for _, cam_name in cam_mapping.items(): - view_name = CAMERA_NAME_MAPPING.get(cam_name, cam_name) - user_request += f" {view_name}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}" + if img_keys: + img_keys = img_key_mapping(img_keys) + for key in img_keys: + user_request += f" {key}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}" user_request += "\nInstruction:" # Get frame-specific instruction @@ -543,6 +541,27 @@ def get_wallx_normal_text( complete_text = prologue + user_message + assistant_output return complete_text, generate_subtask +def img_key_mapping(img_keys: List[str]) -> List[str]: + """Map image keys to camera names. + + Args: + img_keys: List of image keys + + Returns: + List of camera names + """ + processed_img_keys = [] + for key in img_keys: + key = key.replace(OBS_IMAGES + ".", "") + if key in CAMERA_NAME_MAPPING: + key = CAMERA_NAME_MAPPING[key] + else: + if 'view' in key: + key = key.replace('_', ' ') + else: + key = key + " view" + processed_img_keys.append(key) + return processed_img_keys def get_action_tokens( normalized_actions: Union[torch.Tensor, List], action_tokenizer @@ -599,7 +618,6 @@ def replace_action_token( text: List[str], norm_action: Optional[torch.Tensor], action_tokenizer, - dataset_names: List[str], dof_masks: Optional[torch.Tensor] = None, ) -> List[str]: """Replace action placeholders in text with actual action tokens. @@ -615,17 +633,10 @@ def replace_action_token( List of text strings with action tokens replaced """ # Filter out multimodal dataset names - dataset_names = [ - name for name in dataset_names if name not in MULTIMODAL_DATASET_NAMES - ] - - # Get required action chunk sizes - required_chunk_sizes = [32 for name in dataset_names] - if action_tokenizer is not None and norm_action is not None: # Extract actions based on chunk sizes and DOF masks norm_action = [ - action[: required_chunk_sizes[i], dof_masks[i, 0].bool()] + action[: 32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action) ] diff --git a/tests/policies/wall_x/__init__.py b/tests/policies/wall_x/__init__.py new file mode 100644 index 000000000..7f5f042a0 --- /dev/null +++ b/tests/policies/wall_x/__init__.py @@ -0,0 +1,2 @@ +# Wall-X policy tests + diff --git a/tests/policies/wall_x/test_wallx.py b/tests/policies/wall_x/test_wallx.py new file mode 100644 index 000000000..ac8ee59da --- /dev/null +++ b/tests/policies/wall_x/test_wallx.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!""" + +import os + +import pytest +import torch + +# Skip this entire module in CI +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local Wall-X installation and is not meant for CI", +) + +from lerobot.policies.factory import make_policy_config # noqa: E402 +from lerobot.policies.wall_x import ( # noqa: E402 + WallXConfig, + WallXPolicy, + make_wall_x_pre_post_processors, # noqa: E402 +) +from lerobot.utils.random_utils import set_seed # noqa: E402 + +def test_policy_instantiation(): + # Create config + set_seed(42) + config = WallXConfig(device='cuda') + + # Set up input_features and output_features in the config + from lerobot.configs.types import FeatureType, PolicyFeature + + config.input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(7,), + ), + "observation.images.face_view": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), + ), + } + + config.output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(7,), + ), + } + + # Create dummy dataset stats + dataset_stats = { + "observation.state": { + "mean": torch.zeros(7), + "std": torch.ones(7), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + }, + "observation.images.face_view": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + } + + # Instantiate policy + policy = WallXPolicy(config) + preprocessor, postprocessor = make_wall_x_pre_post_processors(config=config, dataset_stats=dataset_stats) + # Test forward pass with dummy data + batch_size = 1 + device = config.device + batch = { + "observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device), + "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device), + "observation.images.face_view": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), # Use rand for [0,1] range + "task": ["Pick up the object"] * batch_size, + } + batch = preprocessor(batch) + try: + loss, loss_dict = policy.forward(batch) + print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}") + except Exception as e: + print(f"Forward pass failed: {e}") + raise + + try: + with torch.no_grad(): + action = policy.select_action(batch) + action = postprocessor(action) + print(f"Action: {action}") + print(f"Action prediction successful. Action shape: {action.shape}") + except Exception as e: + print(f"Action prediction failed: {e}") + raise + +def test_config_creation(): + """Test policy config creation through factory.""" + try: + config = make_policy_config( + policy_type="wall_x", + ) + print("Config created successfully through factory") + print(f" Config type: {type(config).__name__}") + except Exception as e: + print(f"Config creation failed: {e}") + raise + +if __name__ == "__main__": + test_policy_instantiation() + test_config_creation() \ No newline at end of file