mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
reduce to least config and params & pass lerobot basic test
This commit is contained in:
@@ -13,13 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("wall_x")
|
||||
@@ -34,48 +32,15 @@ class WallXConfig(PreTrainedConfig):
|
||||
This config supports multi-modal learning with vision, language, and action data.
|
||||
"""
|
||||
|
||||
# ==================== Model and Paths Configuration ====================
|
||||
# Logging
|
||||
log_name: str = "wall_x_training"
|
||||
log_project: str = "vla_training"
|
||||
model_type: str = "wall-oss"
|
||||
|
||||
# Pretrained model paths
|
||||
pretrained_wallx_path: str | None = None # Path to pretrained Wall-X model
|
||||
save_path: str | None = None # Path to save checkpoints
|
||||
processor_path: str | None = None # Path to processor (defaults to pretrained_wallx_path)
|
||||
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
|
||||
|
||||
# Tokenizer settings
|
||||
use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
|
||||
|
||||
# ==================== Profiling Configuration ====================
|
||||
profile: bool = False
|
||||
profile_save_path: str | None = None
|
||||
profile_wait_iters: int = 10
|
||||
profile_warmup_iters: int = 5
|
||||
profile_active_iters: int = 2
|
||||
|
||||
# ==================== Training Hyperparameters ====================
|
||||
num_warmup_steps: int = 100
|
||||
num_training_steps: int = 64000000
|
||||
learning_rate: float = 5e-5
|
||||
min_lr: float = 5e-5
|
||||
num_epoch: int = 100
|
||||
gradient_accumulation_steps: int = 32
|
||||
batch_size_per_gpu: int = 8
|
||||
padding_side: str = "left"
|
||||
epoch_save_interval: int = 10
|
||||
|
||||
# Training optimization
|
||||
fsdp2: bool = False
|
||||
torch_compile: bool = False
|
||||
|
||||
# ==================== Input / Output Structure ====================
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 32 # action_horizon in wall-x
|
||||
n_action_steps: int = 32
|
||||
|
||||
# Action dimension - wall-x uses 20
|
||||
max_action_dim: int = 20
|
||||
max_state_dim: int = 20 # For proprioception
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
@@ -84,101 +49,17 @@ class WallXConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
# Action dimension - wall-x uses hardcoded 20
|
||||
max_action_dim: int = 20
|
||||
max_state_dim: int = 20 # For proprioception
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] | None = None # wall-x uses Qwen processor
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 256
|
||||
|
||||
# ==================== Model Architecture ====================
|
||||
vlm_model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
load_vlm_weights: bool = True
|
||||
|
||||
# Vision config
|
||||
vision_config: dict = field(default_factory=lambda: {
|
||||
"depth": 32,
|
||||
"hidden_size": 3584,
|
||||
"hidden_act": "silu",
|
||||
"intermediate_size": 3420,
|
||||
"num_heads": 16,
|
||||
"patch_size": 14,
|
||||
"spatial_merge_size": 2,
|
||||
"temporal_patch_size": 2,
|
||||
"window_size": 112,
|
||||
"out_hidden_size": 3584,
|
||||
})
|
||||
|
||||
# Language model config
|
||||
hidden_size: int = 3584 # 8192 for 7B model
|
||||
intermediate_size: int = 18944 # 29568 for 7B model
|
||||
num_hidden_layers: int = 36 # 80 for 7B model
|
||||
num_attention_heads: int = 28 # 64 for 7B model
|
||||
num_key_value_heads: int = 4 # 8 for 7B model
|
||||
vocab_size: int = 152064
|
||||
|
||||
# ==================== Action Prediction ====================
|
||||
# Action prediction mode: "flow" or "fast"
|
||||
prediction_mode: str = "flow"
|
||||
# Pretrained model paths
|
||||
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||
|
||||
# Flow matching parameters
|
||||
noise_scheduler: dict = field(default_factory=lambda: {
|
||||
"beta_alpha": 1.5, # Beta distribution concentration1
|
||||
"beta_beta": 1.0, # Beta distribution concentration0
|
||||
"s": 0.999, # Scaling factor for time
|
||||
})
|
||||
# Action prediction mode: "diffusion" or "fast"
|
||||
prediction_mode: str = "diffusion"
|
||||
|
||||
# Decoding parameters
|
||||
num_inference_timesteps: int = 10 # Number of ODE solver steps
|
||||
ode_solver_method: str = "euler" # ODE solver method
|
||||
# Tokenizer settings
|
||||
use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
|
||||
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
|
||||
|
||||
# ==================== Robot Configuration ====================
|
||||
# Degrees of freedom configuration - defines action space
|
||||
dof_config: dict = field(default_factory=lambda: {
|
||||
"left_ee_pos": 3,
|
||||
"left_ee_rot": 3,
|
||||
"left_gripper": 1,
|
||||
"right_ee_pos": 3,
|
||||
"right_ee_rot": 3,
|
||||
"right_gripper": 1,
|
||||
})
|
||||
|
||||
# Proprioception configuration (typically mirrors dof_config)
|
||||
agent_pos_config: dict = field(default_factory=lambda: {
|
||||
"left_ee_pos": 3,
|
||||
"left_ee_rot": 3,
|
||||
"left_gripper": 1,
|
||||
"right_ee_pos": 3,
|
||||
"right_ee_rot": 3,
|
||||
"right_gripper": 1,
|
||||
})
|
||||
|
||||
# Customized robot configuration
|
||||
enable_customized_robot_config: bool = False
|
||||
customized_robot_config: dict = field(default_factory=lambda: {
|
||||
"name": "",
|
||||
"customized_dof_config": {},
|
||||
"customized_agent_pos_config": {},
|
||||
})
|
||||
|
||||
# Normalization statistics path
|
||||
norm_stats_path: str | None = None
|
||||
|
||||
# ==================== MoE Configuration ====================
|
||||
num_experts: int = 4
|
||||
attention_moe: bool = False
|
||||
mlp_moe: bool = False
|
||||
|
||||
# ==================== Finetuning Settings ====================
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = False # wall-x trains more components
|
||||
train_action_head: bool = True
|
||||
|
||||
# Cache
|
||||
use_cache: bool = True
|
||||
|
||||
# ==================== Optimizer Presets ====================
|
||||
optimizer_lr: float = 2e-5
|
||||
@@ -191,44 +72,6 @@ class WallXConfig(PreTrainedConfig):
|
||||
scheduler_decay_steps: int = 100000
|
||||
scheduler_decay_lr: float = 1e-6
|
||||
|
||||
# ==================== Dataset Configuration ====================
|
||||
# Dataset-specific normalization statistics
|
||||
action_statistics: dict = field(default_factory=dict)
|
||||
|
||||
# Data configuration
|
||||
data_config: dict = field(default_factory=lambda: {
|
||||
"use_lerobot": True,
|
||||
"lerobot_config": {
|
||||
"repo_id": "",
|
||||
"root": None,
|
||||
"episodes": None,
|
||||
"image_transforms": None,
|
||||
"delta_timestamps": None,
|
||||
"tolerance_s": 1e-4,
|
||||
"revision": None,
|
||||
"force_cache_sync": False,
|
||||
"download_videos": True,
|
||||
"video_backend": None,
|
||||
},
|
||||
"action_horizon": 32,
|
||||
"train_test_split": 0.95,
|
||||
"obs_action_keys": [],
|
||||
"predict_action_keys": [],
|
||||
"resolution": {
|
||||
"face_view": 256,
|
||||
"left_wrist_view": 256,
|
||||
"right_wrist_view": 256,
|
||||
"move1_view": 256,
|
||||
"move2_view": 256,
|
||||
"top_view": 256,
|
||||
"wall_view": 256,
|
||||
"multi_modal": 256,
|
||||
},
|
||||
})
|
||||
|
||||
# ==================== Resume Configuration ====================
|
||||
resume_config: dict | None = field(default_factory=lambda: None)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@@ -239,243 +82,55 @@ class WallXConfig(PreTrainedConfig):
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
|
||||
if self.prediction_mode not in ["flow", "fast"]:
|
||||
if self.prediction_mode not in ["diffusion", "fast"]:
|
||||
raise ValueError(
|
||||
f"prediction_mode must be 'flow' or 'fast', got {self.prediction_mode}"
|
||||
)
|
||||
|
||||
# Validate dof_config total doesn't exceed max_action_dim
|
||||
total_dof = sum(self.dof_config.values())
|
||||
if total_dof > self.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Total DOF ({total_dof}) exceeds max_action_dim ({self.max_action_dim})"
|
||||
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
||||
)
|
||||
|
||||
# Sync prediction_mode with use_fast_tokenizer
|
||||
if self.use_fast_tokenizer:
|
||||
self.prediction_mode = "fast"
|
||||
else:
|
||||
self.prediction_mode = "flow"
|
||||
self.prediction_mode = "diffusion"
|
||||
|
||||
def get_train_config(self) -> dict:
|
||||
"""
|
||||
Extract the complete train_config dictionary matching the YAML training configuration format.
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"Wall-X policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
This method constructs the full train_config from WallXConfig fields, suitable for
|
||||
training scripts and Qwen2_5_VLMoEForAction.from_pretrained.
|
||||
|
||||
Returns:
|
||||
dict: Complete training configuration matching YAML structure.
|
||||
"""
|
||||
# Build customized_robot_config
|
||||
if self.enable_customized_robot_config and self.customized_robot_config:
|
||||
customized_robot_config = {
|
||||
"name": self.customized_robot_config.get("name", ""),
|
||||
"customized_dof_config": self.customized_robot_config.get(
|
||||
"customized_dof_config", self.dof_config
|
||||
),
|
||||
"customized_agent_pos_config": self.customized_robot_config.get(
|
||||
"customized_agent_pos_config", self.agent_pos_config
|
||||
),
|
||||
}
|
||||
if "observation.state" not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
else:
|
||||
customized_robot_config = {
|
||||
"name": self.data_config.get("lerobot_config", {}).get("repo_id", ""),
|
||||
"customized_dof_config": self.dof_config,
|
||||
"customized_agent_pos_config": self.agent_pos_config,
|
||||
}
|
||||
state_shape = self.input_features["observation.state"].shape
|
||||
state_dim = state_shape[0] if state_shape else 0
|
||||
if state_dim > self.max_state_dim:
|
||||
raise ValueError(
|
||||
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
|
||||
f"Either reduce state dimension or increase max_state_dim in config."
|
||||
)
|
||||
|
||||
train_config = {
|
||||
# Model and paths configuration
|
||||
"log_name": self.log_name,
|
||||
"log_project": self.log_project,
|
||||
"model_type": self.model_type,
|
||||
"pretrained_wallx_path": self.pretrained_wallx_path,
|
||||
"save_path": self.save_path,
|
||||
"use_fast_tokenizer": self.use_fast_tokenizer,
|
||||
"action_tokenizer_path": self.action_tokenizer_path,
|
||||
|
||||
# Profiling configuration
|
||||
"profile": self.profile,
|
||||
"profile_save_path": self.profile_save_path,
|
||||
"profile_wait_iters": self.profile_wait_iters,
|
||||
"profile_warmup_iters": self.profile_warmup_iters,
|
||||
"profile_active_iters": self.profile_active_iters,
|
||||
|
||||
# Training hyperparameters
|
||||
"num_warmup_steps": self.num_warmup_steps,
|
||||
"num_training_steps": self.num_training_steps,
|
||||
"learning_rate": self.learning_rate,
|
||||
"min_lr": self.min_lr,
|
||||
"num_epoch": self.num_epoch,
|
||||
"gradient_accumulation_steps": self.gradient_accumulation_steps,
|
||||
"batch_size_per_gpu": self.batch_size_per_gpu,
|
||||
"padding_side": self.padding_side,
|
||||
"epoch_save_interval": self.epoch_save_interval,
|
||||
|
||||
# Training optimization
|
||||
"FSDP2": self.fsdp2,
|
||||
"torch_compile": self.torch_compile,
|
||||
|
||||
# Robot configuration
|
||||
"dof_config": self.dof_config,
|
||||
"agent_pos_config": self.agent_pos_config,
|
||||
|
||||
# Normalization stats
|
||||
"norm_stats_path": self.norm_stats_path,
|
||||
|
||||
# Customized robot config
|
||||
"enable_customized_robot_config": self.enable_customized_robot_config,
|
||||
"customized_robot_config": customized_robot_config,
|
||||
|
||||
# Resume configuration
|
||||
"resume": self.resume_config,
|
||||
|
||||
# Data configuration
|
||||
"data": self.data_config,
|
||||
}
|
||||
|
||||
return train_config
|
||||
|
||||
def get_dataload_config(self) -> dict:
|
||||
"""
|
||||
Extract data loading configuration from config.
|
||||
|
||||
Returns:
|
||||
dict: Data loading configuration for preprocessing.
|
||||
"""
|
||||
return {
|
||||
"action_horizon": self.data_config.get("action_horizon", self.chunk_size),
|
||||
"train_test_split": self.data_config.get("train_test_split", 0.95),
|
||||
"split_seed": 42,
|
||||
"predict_action_keys": self.data_config.get("predict_action_keys", []),
|
||||
"obs_action_keys": self.data_config.get("obs_action_keys", []),
|
||||
"resolution": self.data_config.get("resolution", {}),
|
||||
"priority_order": None,
|
||||
"max_length": self.tokenizer_max_length,
|
||||
}
|
||||
|
||||
def get_lerobot_config(self) -> dict:
|
||||
"""
|
||||
Extract LeRobot dataset configuration.
|
||||
|
||||
Returns:
|
||||
dict: LeRobot dataset configuration.
|
||||
"""
|
||||
return self.data_config.get("lerobot_config", {})
|
||||
|
||||
@classmethod
|
||||
def from_yaml_dict(cls, yaml_dict: dict) -> "WallXConfig":
|
||||
"""
|
||||
Create a WallXConfig from a YAML configuration dictionary.
|
||||
|
||||
Args:
|
||||
yaml_dict: Dictionary loaded from YAML training config file.
|
||||
|
||||
Returns:
|
||||
WallXConfig instance with values from YAML.
|
||||
"""
|
||||
config_kwargs = {}
|
||||
|
||||
# Model and paths
|
||||
if "log_name" in yaml_dict:
|
||||
config_kwargs["log_name"] = yaml_dict["log_name"]
|
||||
if "log_project" in yaml_dict:
|
||||
config_kwargs["log_project"] = yaml_dict["log_project"]
|
||||
if "model_type" in yaml_dict:
|
||||
config_kwargs["model_type"] = yaml_dict["model_type"]
|
||||
if "pretrained_wallx_path" in yaml_dict:
|
||||
config_kwargs["pretrained_wallx_path"] = yaml_dict["pretrained_wallx_path"]
|
||||
if "save_path" in yaml_dict:
|
||||
config_kwargs["save_path"] = yaml_dict["save_path"]
|
||||
if "use_fast_tokenizer" in yaml_dict:
|
||||
config_kwargs["use_fast_tokenizer"] = yaml_dict["use_fast_tokenizer"]
|
||||
if "action_tokenizer_path" in yaml_dict:
|
||||
config_kwargs["action_tokenizer_path"] = yaml_dict["action_tokenizer_path"]
|
||||
|
||||
# Profiling
|
||||
if "profile" in yaml_dict:
|
||||
config_kwargs["profile"] = yaml_dict["profile"]
|
||||
if "profile_save_path" in yaml_dict:
|
||||
config_kwargs["profile_save_path"] = yaml_dict["profile_save_path"]
|
||||
if "profile_wait_iters" in yaml_dict:
|
||||
config_kwargs["profile_wait_iters"] = yaml_dict["profile_wait_iters"]
|
||||
if "profile_warmup_iters" in yaml_dict:
|
||||
config_kwargs["profile_warmup_iters"] = yaml_dict["profile_warmup_iters"]
|
||||
if "profile_active_iters" in yaml_dict:
|
||||
config_kwargs["profile_active_iters"] = yaml_dict["profile_active_iters"]
|
||||
|
||||
# Training hyperparameters
|
||||
if "num_warmup_steps" in yaml_dict:
|
||||
config_kwargs["num_warmup_steps"] = yaml_dict["num_warmup_steps"]
|
||||
config_kwargs["scheduler_warmup_steps"] = yaml_dict["num_warmup_steps"]
|
||||
if "num_training_steps" in yaml_dict:
|
||||
config_kwargs["num_training_steps"] = yaml_dict["num_training_steps"]
|
||||
config_kwargs["scheduler_decay_steps"] = yaml_dict["num_training_steps"]
|
||||
if "learning_rate" in yaml_dict:
|
||||
config_kwargs["learning_rate"] = yaml_dict["learning_rate"]
|
||||
config_kwargs["optimizer_lr"] = yaml_dict["learning_rate"]
|
||||
if "min_lr" in yaml_dict:
|
||||
config_kwargs["min_lr"] = yaml_dict["min_lr"]
|
||||
config_kwargs["scheduler_decay_lr"] = yaml_dict["min_lr"]
|
||||
if "num_epoch" in yaml_dict:
|
||||
config_kwargs["num_epoch"] = yaml_dict["num_epoch"]
|
||||
if "gradient_accumulation_steps" in yaml_dict:
|
||||
config_kwargs["gradient_accumulation_steps"] = yaml_dict["gradient_accumulation_steps"]
|
||||
if "batch_size_per_gpu" in yaml_dict:
|
||||
config_kwargs["batch_size_per_gpu"] = yaml_dict["batch_size_per_gpu"]
|
||||
if "padding_side" in yaml_dict:
|
||||
config_kwargs["padding_side"] = yaml_dict["padding_side"]
|
||||
if "epoch_save_interval" in yaml_dict:
|
||||
config_kwargs["epoch_save_interval"] = yaml_dict["epoch_save_interval"]
|
||||
|
||||
# Training optimization
|
||||
if "FSDP2" in yaml_dict:
|
||||
config_kwargs["fsdp2"] = yaml_dict["FSDP2"]
|
||||
if "torch_compile" in yaml_dict:
|
||||
config_kwargs["torch_compile"] = yaml_dict["torch_compile"]
|
||||
|
||||
# Robot configuration
|
||||
if "dof_config" in yaml_dict:
|
||||
config_kwargs["dof_config"] = yaml_dict["dof_config"]
|
||||
if "agent_pos_config" in yaml_dict:
|
||||
config_kwargs["agent_pos_config"] = yaml_dict["agent_pos_config"]
|
||||
|
||||
# Normalization stats
|
||||
if "norm_stats_path" in yaml_dict:
|
||||
config_kwargs["norm_stats_path"] = yaml_dict["norm_stats_path"]
|
||||
|
||||
# Customized robot config
|
||||
if "enable_customized_robot_config" in yaml_dict:
|
||||
config_kwargs["enable_customized_robot_config"] = yaml_dict["enable_customized_robot_config"]
|
||||
if "customized_robot_config" in yaml_dict:
|
||||
config_kwargs["customized_robot_config"] = yaml_dict["customized_robot_config"]
|
||||
|
||||
# Resume config
|
||||
if "resume" in yaml_dict:
|
||||
config_kwargs["resume_config"] = yaml_dict["resume"]
|
||||
|
||||
# Data configuration
|
||||
if "data" in yaml_dict:
|
||||
data = yaml_dict["data"]
|
||||
data_config = {
|
||||
"use_lerobot": data.get("use_lerobot", True),
|
||||
"action_horizon": data.get("action_horizon", 32),
|
||||
"train_test_split": data.get("train_test_split", 0.95),
|
||||
"obs_action_keys": data.get("obs_action_keys", []),
|
||||
"predict_action_keys": data.get("predict_action_keys", []),
|
||||
"resolution": data.get("resolution", {}),
|
||||
}
|
||||
if "lerobot_config" in data:
|
||||
data_config["lerobot_config"] = data["lerobot_config"]
|
||||
config_kwargs["data_config"] = data_config
|
||||
|
||||
# Set chunk_size from action_horizon
|
||||
if "action_horizon" in data:
|
||||
config_kwargs["chunk_size"] = data["action_horizon"]
|
||||
config_kwargs["n_action_steps"] = data["action_horizon"]
|
||||
|
||||
return cls(**config_kwargs)
|
||||
if "action" not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
else:
|
||||
action_shape = self.output_features["action"].shape
|
||||
action_dim = action_shape[0] if action_shape else 0
|
||||
if action_dim > self.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
|
||||
f"Either reduce action dimension or increase max_action_dim in config."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
@@ -496,7 +151,7 @@ class WallXConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return [0]
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
|
||||
@@ -16,15 +16,9 @@
|
||||
|
||||
"""
|
||||
Wall-X Constants and Configuration Data.
|
||||
|
||||
Contains dataset names, key mappings, frequency mappings, and action statistics
|
||||
for cross-embodiment robotic control.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# Add wall-x repo to path if available
|
||||
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_IMAGES, ACTION
|
||||
|
||||
CAMERA_NAME_MAPPING = {
|
||||
"face_view": "front view",
|
||||
@@ -35,3 +29,15 @@ CAMERA_NAME_MAPPING = {
|
||||
"wall_view": "wall view",
|
||||
"top_view": "top view",
|
||||
}
|
||||
|
||||
RESOLUTION = 256
|
||||
|
||||
# Parameters for preprocessing
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MIN_PIXELS = 4 * 28 * 28
|
||||
IMAGE_FACTOR = 28
|
||||
PRIORITY_ORDER = None
|
||||
GENERATE_SUBTASK_RATIO = 0.0
|
||||
MODEL_TYPE = "qwen2_5"
|
||||
|
||||
TOKENIZER_MAX_LENGTH = 768
|
||||
@@ -38,6 +38,7 @@ import builtins
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
from os import PathLike
|
||||
import sys
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
@@ -57,16 +58,8 @@ from torchdiffeq import odeint
|
||||
from transformers import AutoConfig, AutoProcessor
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.cache_utils import (
|
||||
Cache,
|
||||
DynamicCache,
|
||||
SlidingWindowCache,
|
||||
StaticCache,
|
||||
)
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import is_torchdynamo_compiling, logging
|
||||
from transformers import AutoProcessor, BatchFeature
|
||||
from qwen_vl_utils.vision_process import smart_resize
|
||||
@@ -80,29 +73,18 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
|
||||
from lerobot.policies.wall_x.utils import *
|
||||
from lerobot.policies.wall_x.constant import *
|
||||
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
Qwen2RMSNorm,
|
||||
)
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VLRotaryEmbedding,
|
||||
Qwen2_5_VLPreTrainedModel,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
)
|
||||
|
||||
from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
|
||||
Qwen2_5_VisionTransformerPretrainedModel,
|
||||
Qwen2_5_VLDecoderLayer_with_MoE,
|
||||
Qwen2_5_VLACausalLMOutputWithPast,
|
||||
Qwen2_5_VLMoEModel,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Add wall-x repo to path if available
|
||||
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
|
||||
if WALL_X_PATH.exists():
|
||||
sys.path.insert(0, str(WALL_X_PATH))
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
"""Sinusoidal positional embedding for diffusion timesteps."""
|
||||
@@ -129,7 +111,7 @@ class ActionHead(nn.Module):
|
||||
for action sequence prediction.
|
||||
"""
|
||||
|
||||
def __init__(self, config: WallXConfig):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@@ -138,10 +120,9 @@ class ActionHead(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
# Beta distribution for noise scheduling
|
||||
noise_config = config.noise_scheduler
|
||||
self.beta_alpha = noise_config.get("beta_alpha", 1.5)
|
||||
self.beta_beta = noise_config.get("beta_beta", 1.0)
|
||||
self.s = noise_config.get("s", 0.999)
|
||||
self.beta_alpha = 1.5
|
||||
self.beta_beta = 1.0
|
||||
self.s = 0.999
|
||||
|
||||
# Sinusoidal timestep embedding
|
||||
self.time_embed = SinusoidalPosEmb(config.hidden_size)
|
||||
@@ -277,12 +258,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_path,
|
||||
train_config,
|
||||
config_path=None,
|
||||
processor_path=None,
|
||||
pretrained_name_or_path,
|
||||
config=None,
|
||||
action_tokenizer_path=None,
|
||||
**kwargs,
|
||||
cache_dir: str | PathLike | None = None,
|
||||
force_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: str | bool | None = None,
|
||||
revision: str = "main",
|
||||
strict: bool = False,
|
||||
**kwargs: Any
|
||||
):
|
||||
"""
|
||||
Load model from pretrained model path.
|
||||
@@ -290,17 +275,24 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
Args:
|
||||
pretrained_model_path (str): Model directory path containing model.safetensors file
|
||||
config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
|
||||
processor_path (str, optional): Processor path, if None will load from default config
|
||||
action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
Qwen2_5_VLMoEForAction: Loaded model instance
|
||||
"""
|
||||
# Load model components from pretrained path
|
||||
config_path = os.path.join(pretrained_model_path, "config.json")
|
||||
config = cls.config_class.from_pretrained(config_path)
|
||||
processor = AutoProcessor.from_pretrained(pretrained_model_path, use_fast=True)
|
||||
if config is None:
|
||||
config = cls.config_class.from_pretrained(
|
||||
pretrained_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
strict=strict,
|
||||
**kwargs,
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
|
||||
if action_tokenizer_path is not None:
|
||||
processor.action_processor = AutoProcessor.from_pretrained(
|
||||
action_tokenizer_path, trust_remote_code=True
|
||||
@@ -312,18 +304,37 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
# Resize token embeddings to match processor tokenizer vocabulary size
|
||||
model.resize_token_embeddings(len(processor.tokenizer))
|
||||
|
||||
# Load model state dict from safetensors file
|
||||
safetensor_files = glob.glob(
|
||||
os.path.join(pretrained_model_path, "*.safetensors")
|
||||
# Try to load the model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
cache_dir=kwargs.get("cache_dir"),
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
from safetensors.torch import load_file
|
||||
|
||||
sd = load_file(resolved_file)
|
||||
print("✓ Loaded state dict from model.safetensors")
|
||||
except Exception as e:
|
||||
print(f"Could not load state dict from remote files: {e}")
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
state_dict = {}
|
||||
for file in safetensor_files:
|
||||
sd = load_file(file, device="cpu")
|
||||
# filter normalizer statistic params
|
||||
del_keys = []
|
||||
for key in sd.keys():
|
||||
if "action_preprocessor.normalizer" in key:
|
||||
print(f"filter load model weight {key}")
|
||||
del_keys.append(key)
|
||||
for key in del_keys:
|
||||
del sd[key]
|
||||
@@ -730,7 +741,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
dataset_names: Optional[str] = None,
|
||||
dof_mask: Optional[torch.FloatTensor] = None,
|
||||
agent_pos_mask: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
@@ -763,7 +773,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
rope_deltas (torch.LongTensor, optional): RoPE position deltas
|
||||
cache_position (torch.LongTensor, optional): Cache position indices
|
||||
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
|
||||
dataset_names (str, optional): Names of datasets in the current batch
|
||||
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens
|
||||
agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data
|
||||
**kwargs: Additional keyword arguments
|
||||
@@ -872,7 +881,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
)
|
||||
proprioception = self.action_preprocessor.proprioception_proj(
|
||||
proprioception,
|
||||
dataset_names,
|
||||
agent_pos_mask,
|
||||
use_history=proprioception.shape[1] > 1,
|
||||
)
|
||||
@@ -909,7 +917,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
)
|
||||
dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
||||
noisy_action_emb, flow = self.action_preprocessor(
|
||||
action_chunk, dataset_names, dof_mask
|
||||
action_chunk, dof_mask
|
||||
)
|
||||
mask = input_ids == self.action_token_id_set["action_token_id"]
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
@@ -953,18 +961,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
# Compute losses if labels are provided
|
||||
if labels is not None:
|
||||
loss = 0
|
||||
action_accuracy = 0
|
||||
unique_datasets_name = list(set(dataset_names))
|
||||
|
||||
# Initialize per-dataset loss tracking dictionaries
|
||||
channel_loss_dict = {
|
||||
dataset_name: torch.tensor(0.0, device=logits.device)
|
||||
for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES
|
||||
}
|
||||
channel_loss_count_dict = {
|
||||
dataset_name: torch.tensor(0, device=logits.device)
|
||||
for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES
|
||||
}
|
||||
|
||||
# Compute standard cross-entropy loss for language modeling
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
@@ -982,22 +978,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
else torch.tensor(0.0, device=shift_logits.device)
|
||||
)
|
||||
|
||||
# Compute per-dataset channel losses
|
||||
_cross_entropy_loss = _cross_entropy_loss.view(batch_size, seq_length - 1)
|
||||
non_ignored_mask = non_ignored_mask.view(batch_size, seq_length - 1)
|
||||
for dataset_name_i in unique_datasets_name:
|
||||
dataset_mask = torch.tensor(
|
||||
[name == dataset_name_i for name in dataset_names],
|
||||
device=logits.device,
|
||||
)
|
||||
combined_mask = dataset_mask.unsqueeze(1) & non_ignored_mask
|
||||
channel_loss_dict[dataset_name_i] = (
|
||||
_cross_entropy_loss[combined_mask].sum()
|
||||
if combined_mask.any()
|
||||
else torch.tensor(0.0, device=shift_logits.device)
|
||||
)
|
||||
channel_loss_count_dict[dataset_name_i] += combined_mask.sum()
|
||||
|
||||
# Add cross-entropy loss to total loss if valid
|
||||
if not torch.isnan(cross_entropy_loss):
|
||||
loss += cross_entropy_loss
|
||||
@@ -1005,20 +985,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
with torch.no_grad():
|
||||
cross_entropy_loss.detach()
|
||||
|
||||
# Compute action token prediction accuracy
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
action_preds = shift_logits.argmax(dim=-1)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
if self.use_fast_tokenizer:
|
||||
action_mask = (
|
||||
shift_labels > self.action_token_id_set["fast_action_token_list"][0]
|
||||
)
|
||||
correct_preds = (action_preds == shift_labels) & action_mask
|
||||
action_accuracy = (
|
||||
correct_preds.sum().float() / action_mask.sum().float()
|
||||
)
|
||||
channel_loss_dict["action_accuracy"] = action_accuracy
|
||||
|
||||
if action_chunk is not None:
|
||||
action_mask = input_ids == self.action_token_id_set["action_token_id"]
|
||||
if action_mask.any():
|
||||
@@ -1101,7 +1067,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
num_inference_timesteps: Optional[int] = 10,
|
||||
dataset_names: Optional[str] = None,
|
||||
dof_mask: Optional[torch.FloatTensor] = None,
|
||||
agent_pos_mask: Optional[torch.FloatTensor] = None,
|
||||
re_generate: bool = False,
|
||||
@@ -1140,7 +1105,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
cache_position (torch.LongTensor, optional): Cache position indices
|
||||
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
|
||||
num_inference_timesteps (int, optional): Number of diffusion inference steps
|
||||
dataset_names (str, optional): Dataset names for normalization
|
||||
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask
|
||||
agent_pos_mask (torch.FloatTensor, optional): Agent position mask
|
||||
re_generate (bool, optional): Whether to use sampling for regeneration
|
||||
@@ -1239,7 +1203,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
)
|
||||
proprio_embed = self.action_preprocessor.proprioception_proj(
|
||||
proprioception,
|
||||
dataset_names,
|
||||
agent_pos_mask,
|
||||
use_history=proprioception.shape[1] > 1,
|
||||
)
|
||||
@@ -1339,7 +1302,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
"dof_mask": dof_mask,
|
||||
"agent_pos_mask": agent_pos_mask,
|
||||
"proprioception": proprioception,
|
||||
"dataset_names": dataset_names,
|
||||
}
|
||||
|
||||
# Generate output tokens
|
||||
@@ -1545,7 +1507,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
image_grid_thw=None,
|
||||
video_grid_thw=None,
|
||||
second_per_grid_ts=None,
|
||||
dataset_names=None,
|
||||
proprioception=None,
|
||||
dof_mask=None,
|
||||
agent_pos_mask=None,
|
||||
@@ -1572,7 +1533,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
image_grid_thw: Image grid dimensions
|
||||
video_grid_thw: Video grid dimensions
|
||||
second_per_grid_ts: Time interval per temporal grid
|
||||
dataset_names: Dataset names for processing
|
||||
proprioception: Proprioceptive sensor data
|
||||
dof_mask: Degrees of freedom mask
|
||||
agent_pos_mask: Agent position mask
|
||||
@@ -1677,7 +1637,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
"cache_position": cache_position,
|
||||
"second_per_grid_ts": second_per_grid_ts,
|
||||
"proprioception": proprioception,
|
||||
"dataset_names": dataset_names,
|
||||
"dof_mask": dof_mask,
|
||||
"agent_pos_mask": agent_pos_mask,
|
||||
}
|
||||
@@ -1866,150 +1825,14 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
# Initialize VLM wrapper
|
||||
self.model = Qwen2_5_VLMoEForAction(config)
|
||||
# Initialize the wall-x model
|
||||
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
|
||||
self.model.to(config.device)
|
||||
# Convert to bfloat16 for Flash Attention compatibility
|
||||
self.model.to_bfloat16_for_selected_params()
|
||||
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""
|
||||
Load WallXPolicy from a pretrained model path.
|
||||
|
||||
Args:
|
||||
pretrained_name_or_path: Path to pretrained model or model identifier
|
||||
config: Optional configuration object
|
||||
force_download: Force download even if cached
|
||||
resume_download: Resume interrupted download
|
||||
proxies: Proxy configuration
|
||||
token: Authentication token
|
||||
cache_dir: Cache directory path
|
||||
local_files_only: Only use local files
|
||||
revision: Model revision
|
||||
strict: Strict loading of state dict
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
WallXPolicy: Loaded policy instance
|
||||
"""
|
||||
print(
|
||||
"Loading Wall-X model for cross-embodiment robotic control.\n"
|
||||
"This implementation integrates Qwen2.5-VL with flow matching for action prediction."
|
||||
)
|
||||
|
||||
if pretrained_name_or_path is None:
|
||||
raise ValueError("pretrained_name_or_path is required")
|
||||
|
||||
# Use provided config if available, otherwise load from pretrained path
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Initialize model without loading weights
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Load and remap the state dict
|
||||
try:
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
original_state_dict = load_file(resolved_file)
|
||||
print("✓ Loaded state dict from model.safetensors")
|
||||
except Exception:
|
||||
print(f"Could not load state dict: {e}")
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# Filter out normalizer statistics if present
|
||||
filtered_state_dict = {}
|
||||
for key, value in original_state_dict.items():
|
||||
if "action_preprocessor.normalizer" not in key:
|
||||
filtered_state_dict[key] = value
|
||||
else:
|
||||
print(f"Filtered key: {key}")
|
||||
|
||||
# Add "model." prefix for keys that don't have it
|
||||
remapped_state_dict = {}
|
||||
remap_count = 0
|
||||
|
||||
for key, value in filtered_state_dict.items():
|
||||
if not key.startswith("model."):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
if remap_count > 0:
|
||||
print(f"Remapped {remap_count} state dict keys")
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset action queue."""
|
||||
self._queues = {
|
||||
@@ -2022,88 +1845,50 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
|
||||
def preprocess_inputs(
|
||||
self,
|
||||
batch: List[Dict[str, Any]],
|
||||
config: Dict[str, Any],
|
||||
dataload_config: Dict[str, Any],
|
||||
lerobot_config: Dict[str, Any],
|
||||
processor: Any,
|
||||
action_tokenizer: Optional[Any] = None,
|
||||
camera_keys: Optional[List[str]] = None,
|
||||
batch: Dict[str, Any],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Convert a batch of LeRobot dataset items to Wall-X model input format.
|
||||
|
||||
This is the batch version of convert_lerobot_to_wallx_format, processing multiple
|
||||
samples together for efficient batched inference or training.
|
||||
This processes a batched dictionary where tensors have batch dimension first.
|
||||
|
||||
Args:
|
||||
batch: List of items from LeRobot dataset
|
||||
config: Model configuration dict
|
||||
dataload_config: Data loading configuration
|
||||
norm_stats: Normalization statistics
|
||||
lerobot_config: LeRobot config containing 'repo_id'
|
||||
processor: Hugging Face processor for tokenization
|
||||
action_tokenizer: Optional action tokenizer
|
||||
camera_keys: List of camera keys in the dataset
|
||||
batch: Dictionary with batched tensors:
|
||||
- "observation.state": (batch_size, state_dim) or (batch_size, n_obs_steps, state_dim)
|
||||
- "action": (batch_size, chunk_size, action_dim)
|
||||
- "observation.images.<key>": (batch_size, C, H, W)
|
||||
- "task": List[str] of length batch_size
|
||||
|
||||
Returns:
|
||||
BatchFeature containing batched model inputs
|
||||
|
||||
Example:
|
||||
>>> batch = [dataset[i] for i in range(8)]
|
||||
>>> result = preprocess_inputs(
|
||||
... batch=batch,
|
||||
... config=config,
|
||||
... dataload_config=dataload_config,
|
||||
... norm_stats=norm_stats,
|
||||
... lerobot_config={'repo_id': 'lerobot/aloha_mobile_cabinet'},
|
||||
... processor=processor,
|
||||
... )
|
||||
"""
|
||||
repo_id = lerobot_config["repo_id"]
|
||||
use_fast_tokenizer = config.get("use_fast_tokenizer", False)
|
||||
use_fast_tokenizer = self.config.use_fast_tokenizer
|
||||
|
||||
# Get key mappings
|
||||
cam_key_mapping = KEY_MAPPINGS[repo_id]["camera"]
|
||||
state_key = KEY_MAPPINGS[repo_id]["state"]
|
||||
action_key = KEY_MAPPINGS[repo_id]["action"]
|
||||
|
||||
# Build data config
|
||||
data_config = X2RDataProcessingConfig().update(
|
||||
train_test_split=dataload_config.get("train_test_split", 0.95),
|
||||
split_seed=dataload_config.get("split_seed", 42),
|
||||
predict_action_keys=dataload_config.get("predict_action_keys", []),
|
||||
obs_action_keys=dataload_config.get("obs_action_keys", []),
|
||||
resolution=dataload_config.get("resolution", None),
|
||||
priority_order=dataload_config.get("priority_order", None),
|
||||
)
|
||||
|
||||
if camera_keys is None:
|
||||
camera_keys = list(cam_key_mapping.keys())
|
||||
# Get batch size from state tensor
|
||||
batch_size = batch[OBS_STATE].shape[0]
|
||||
|
||||
# ==================== PROCESS ALL SAMPLES ====================
|
||||
all_image_inputs = []
|
||||
all_texts = []
|
||||
all_agent_pos = []
|
||||
all_actions = []
|
||||
all_frame_indices = []
|
||||
|
||||
for data in batch:
|
||||
# Find image keys in batch
|
||||
img_keys = [key for key in self.config.image_features if key in batch]
|
||||
|
||||
for i in range(batch_size):
|
||||
# Vision preprocessing per sample
|
||||
processed_frames = []
|
||||
orig_height, orig_width = None, None
|
||||
resized_height, resized_width = None, None
|
||||
|
||||
for key in camera_keys:
|
||||
current_obs = data[key].clone()
|
||||
for key in img_keys:
|
||||
current_obs = batch[key][i].clone() # (C, H, W)
|
||||
if current_obs.dim() == 3:
|
||||
current_obs = current_obs.permute(1, 2, 0)
|
||||
current_obs = current_obs.permute(1, 2, 0) # (H, W, C)
|
||||
|
||||
img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy())
|
||||
orig_width, orig_height = img_pil.size
|
||||
|
||||
cam_name = cam_key_mapping.get(key, key)
|
||||
target_size = data_config.resolution.get(cam_name, -1)
|
||||
target_size = RESOLUTION
|
||||
if target_size != -1:
|
||||
if orig_width > orig_height:
|
||||
new_width = target_size
|
||||
@@ -2117,9 +1902,9 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
resized_height, resized_width = smart_resize(
|
||||
current_height,
|
||||
current_width,
|
||||
factor=data_config.image_factor,
|
||||
min_pixels=data_config.min_pixels,
|
||||
max_pixels=data_config.max_pixels,
|
||||
factor=IMAGE_FACTOR,
|
||||
min_pixels=MIN_PIXELS,
|
||||
max_pixels=MAX_PIXELS,
|
||||
)
|
||||
resized_img = img_pil.resize((resized_width, resized_height))
|
||||
processed_frames.append(resized_img)
|
||||
@@ -2127,33 +1912,30 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
all_image_inputs.append(processed_frames)
|
||||
|
||||
# Text preprocessing
|
||||
frame_index = data["frame_index"]
|
||||
instruction_info = {"instruction": data["task"]}
|
||||
task_text = batch["task"][i] if isinstance(batch["task"], list) else batch["task"]
|
||||
instruction_info = {"instruction": task_text}
|
||||
|
||||
frame_index = batch["frame_index"][i] if "frame_index" in batch else 0
|
||||
complete_text, _ = get_wallx_normal_text(
|
||||
instruction_info,
|
||||
dataload_config.get("action_horizon", 33) - 1,
|
||||
self.config.chunk_size,
|
||||
frame_index,
|
||||
data_config.priority_order,
|
||||
cam_key_mapping,
|
||||
generate_subtask_ratio=data_config.generate_subtask_ratio,
|
||||
PRIORITY_ORDER,
|
||||
img_keys,
|
||||
generate_subtask_ratio=GENERATE_SUBTASK_RATIO,
|
||||
)
|
||||
|
||||
text = process_grounding_points(
|
||||
complete_text, orig_height, orig_width, resized_height, resized_width,
|
||||
data_config.model_type
|
||||
MODEL_TYPE
|
||||
)
|
||||
all_texts.append(text)
|
||||
|
||||
# Collect raw values
|
||||
all_agent_pos.append(data[state_key])
|
||||
all_actions.append(data[action_key])
|
||||
all_frame_indices.append(frame_index)
|
||||
|
||||
# Stack agent_pos
|
||||
agent_pos = torch.stack(all_agent_pos)
|
||||
# ==================== PROCESS AGENT POS ====================
|
||||
agent_pos = batch[OBS_STATE] # (batch_size, state_dim)
|
||||
if agent_pos.dim() == 2:
|
||||
agent_pos = agent_pos.unsqueeze(1)
|
||||
agent_pos = agent_pos.unsqueeze(1) # (batch_size, 1, state_dim)
|
||||
agent_pos_mask = (~torch.isnan(agent_pos)).float()
|
||||
agent_pos = agent_pos.nan_to_num(nan=0.0)
|
||||
|
||||
@@ -2161,15 +1943,15 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
pad_size = 20 - agent_pos.shape[-1]
|
||||
agent_pos = torch.cat([
|
||||
agent_pos,
|
||||
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size)
|
||||
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device)
|
||||
], dim=-1)
|
||||
agent_pos_mask = torch.cat([
|
||||
agent_pos_mask,
|
||||
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size)
|
||||
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size, device=agent_pos_mask.device)
|
||||
], dim=-1)
|
||||
|
||||
# Stack actions
|
||||
action = torch.stack(all_actions)
|
||||
# ==================== PROCESS ACTIONS ====================
|
||||
action = batch[ACTION] # (batch_size, chunk_size, action_dim)
|
||||
if action.dim() == 2:
|
||||
action = action.unsqueeze(1)
|
||||
dof_mask = (~torch.isnan(action)).float()
|
||||
@@ -2179,36 +1961,35 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
pad_size = 20 - action.shape[-1]
|
||||
action = torch.cat([
|
||||
action,
|
||||
torch.zeros(action.shape[0], action.shape[1], pad_size)
|
||||
torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
|
||||
], dim=-1)
|
||||
dof_mask = torch.cat([
|
||||
dof_mask,
|
||||
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size)
|
||||
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
|
||||
], dim=-1)
|
||||
|
||||
# ==================== ACTION TOKEN REPLACEMENT ====================
|
||||
all_texts = replace_action_token(
|
||||
all_texts,
|
||||
action,
|
||||
action_tokenizer if use_fast_tokenizer else None,
|
||||
[repo_id] * len(batch),
|
||||
self.model.action_tokenizer if use_fast_tokenizer else None,
|
||||
dof_mask,
|
||||
)
|
||||
|
||||
# ==================== TOKENIZATION ====================
|
||||
inputs = preprocesser_call(
|
||||
processor=processor,
|
||||
processor=self.model.processor,
|
||||
text=all_texts,
|
||||
images=all_image_inputs,
|
||||
videos=None,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
max_length=dataload_config.get("max_length", 768),
|
||||
max_length=TOKENIZER_MAX_LENGTH,
|
||||
)
|
||||
|
||||
# ==================== ADDITIONAL INPUTS ====================
|
||||
action_token_id = processor.tokenizer.convert_tokens_to_ids("<|action|>")
|
||||
action_token_id = self.model.processor.tokenizer.convert_tokens_to_ids("<|action|>")
|
||||
moe_token_types = inputs.input_ids == action_token_id
|
||||
|
||||
inputs["proprioception"] = agent_pos
|
||||
@@ -2216,11 +1997,13 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
inputs["action_chunk"] = action
|
||||
inputs["dof_mask"] = dof_mask
|
||||
inputs["moe_token_types"] = moe_token_types
|
||||
inputs["dataset_names"] = [repo_id] * len(batch)
|
||||
inputs["frame_index"] = torch.stack([
|
||||
torch.tensor(fi) if not isinstance(fi, torch.Tensor) else fi
|
||||
for fi in all_frame_indices
|
||||
])
|
||||
inputs["frame_index"] = batch["frame_index"] if "frame_index" in batch else torch.zeros(batch_size, device=batch[OBS_STATE].device)
|
||||
|
||||
# Move all tensors to the correct device
|
||||
device = self.config.device
|
||||
for key, value in inputs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
inputs[key] = value.to(device)
|
||||
|
||||
return inputs
|
||||
|
||||
@@ -2232,37 +2015,19 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
batch: Dictionary containing preprocessed inputs from preprocess_inputs()
|
||||
Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw,
|
||||
proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types,
|
||||
dataset_names, etc.
|
||||
etc.
|
||||
|
||||
Returns:
|
||||
tuple: (loss, loss_dict)
|
||||
"""
|
||||
batch = self.preprocess_inputs(
|
||||
batch,
|
||||
self.config,
|
||||
self.dataload_config,
|
||||
self.lerobot_config,
|
||||
self.processor,
|
||||
self.action_tokenizer,
|
||||
self.camera_keys
|
||||
)
|
||||
|
||||
# Call the underlying model's forward with mode="train"
|
||||
outputs = self.model(
|
||||
mode="train",
|
||||
input_ids=batch.get("input_ids"),
|
||||
attention_mask=batch.get("attention_mask"),
|
||||
pixel_values=batch.get("pixel_values"),
|
||||
image_grid_thw=batch.get("image_grid_thw"),
|
||||
pixel_values_videos=batch.get("pixel_values_videos"),
|
||||
video_grid_thw=batch.get("video_grid_thw"),
|
||||
proprioception=batch.get("proprioception"),
|
||||
agent_pos_mask=batch.get("agent_pos_mask"),
|
||||
action_chunk=batch.get("action_chunk"),
|
||||
dof_mask=batch.get("dof_mask"),
|
||||
moe_token_types=batch.get("moe_token_types"),
|
||||
dataset_names=batch.get("dataset_names"),
|
||||
labels=batch.get("labels", batch.get("input_ids")), # Use input_ids as labels if not provided
|
||||
**batch,
|
||||
mode="train"
|
||||
)
|
||||
|
||||
# Extract losses from output
|
||||
@@ -2292,16 +2057,10 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
|
||||
batch = self.preprocess_inputs(
|
||||
batch,
|
||||
self.config,
|
||||
self.dataload_config,
|
||||
self.lerobot_config,
|
||||
self.processor,
|
||||
self.action_tokenizer,
|
||||
self.camera_keys
|
||||
)
|
||||
|
||||
if self.config.prediction_mode == "diffusion":
|
||||
actions = self.model(
|
||||
output = self.model(
|
||||
**batch,
|
||||
action_dim=self.config.max_action_dim,
|
||||
pred_horizon=self.config.chunk_size,
|
||||
@@ -2309,9 +2068,9 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
predict_mode="diffusion"
|
||||
)
|
||||
elif self.config.prediction_mode == "fast":
|
||||
actions = self.model(
|
||||
output = self.model(
|
||||
**batch,
|
||||
action_dim=self.config.action_feature.shape[0],
|
||||
action_dim=self.config.output_features["action"].shape[0],
|
||||
pred_horizon=self.config.chunk_size,
|
||||
mode="predict",
|
||||
predict_mode="fast"
|
||||
@@ -2319,8 +2078,12 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
else:
|
||||
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
|
||||
|
||||
# Unpad actions
|
||||
actions = actions[:, :, :self.config.action_feature.shape[0]]
|
||||
# Extract action tensor from output dictionary
|
||||
actions = output["predict_action"]
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
action_dim = self.config.output_features["action"].shape[0]
|
||||
actions = actions[:, :, :action_dim]
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
@@ -29,13 +29,10 @@ from lerobot.processor import (
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_wall_x_pre_post_processors(
|
||||
config: WallXConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
@@ -49,7 +46,6 @@ def make_wall_x_pre_post_processors(
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations
|
||||
2. Adding a batch dimension
|
||||
3. Tokenizing language task descriptions
|
||||
4. Normalizing input and output features based on dataset statistics
|
||||
5. Moving all data to the specified device
|
||||
|
||||
@@ -65,25 +61,10 @@ def make_wall_x_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines
|
||||
"""
|
||||
|
||||
# Try to use Qwen processor if available
|
||||
try:
|
||||
from transformers import AutoProcessor
|
||||
tokenizer_name = config.vlm_model_name
|
||||
qwen_available = True
|
||||
except ImportError:
|
||||
tokenizer_name = "Qwen/Qwen2-VL-2B-Instruct" # Fallback
|
||||
qwen_available = False
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
WallXTaskProcessor(), # Process task description
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=tokenizer_name,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
max_length=config.tokenizer_max_length,
|
||||
),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
@@ -152,30 +133,3 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="wall_x_image_processor")
|
||||
class WallXImageProcessor(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Image processor for Wall-X using Qwen-VL vision processing.
|
||||
|
||||
This handles image formatting according to Qwen-VL requirements.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
from transformers import AutoProcessor
|
||||
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
self.available = True
|
||||
except ImportError:
|
||||
self.available = False
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
# Image processing is handled by the VLM processor
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -33,10 +33,8 @@ from transformers import BatchFeature
|
||||
|
||||
from lerobot.policies.wall_x.constant import (
|
||||
CAMERA_NAME_MAPPING,
|
||||
FREQUENCY_MAPPING,
|
||||
KEY_MAPPINGS,
|
||||
MULTIMODAL_DATASET_NAMES,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -457,7 +455,7 @@ def get_wallx_normal_text(
|
||||
action_chunk_size: int,
|
||||
frame_idx: int,
|
||||
priority_order: Optional[OrderedDict] = None,
|
||||
cam_mapping: Optional[Dict[str, str]] = None,
|
||||
img_keys: Optional[List[str]] = None,
|
||||
generate_subtask_ratio: float = 0.0,
|
||||
) -> Tuple[str, bool]:
|
||||
"""Construct complete multimodal prompt text for Wall-X model.
|
||||
@@ -474,7 +472,7 @@ def get_wallx_normal_text(
|
||||
action_chunk_size: Number of action tokens to generate
|
||||
frame_idx: Current frame index
|
||||
priority_order: Priority order for instruction sampling
|
||||
cam_mapping: Camera name mapping dictionary
|
||||
img_keys: List of image keys
|
||||
generate_subtask_ratio: Probability of generating subtask instead of actions
|
||||
|
||||
Returns:
|
||||
@@ -497,10 +495,10 @@ def get_wallx_normal_text(
|
||||
|
||||
# User request with observation
|
||||
user_request = f"{role_start_symbol}user\nObservation:"
|
||||
if cam_mapping:
|
||||
for _, cam_name in cam_mapping.items():
|
||||
view_name = CAMERA_NAME_MAPPING.get(cam_name, cam_name)
|
||||
user_request += f" {view_name}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}"
|
||||
if img_keys:
|
||||
img_keys = img_key_mapping(img_keys)
|
||||
for key in img_keys:
|
||||
user_request += f" {key}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}"
|
||||
user_request += "\nInstruction:"
|
||||
|
||||
# Get frame-specific instruction
|
||||
@@ -543,6 +541,27 @@ def get_wallx_normal_text(
|
||||
complete_text = prologue + user_message + assistant_output
|
||||
return complete_text, generate_subtask
|
||||
|
||||
def img_key_mapping(img_keys: List[str]) -> List[str]:
|
||||
"""Map image keys to camera names.
|
||||
|
||||
Args:
|
||||
img_keys: List of image keys
|
||||
|
||||
Returns:
|
||||
List of camera names
|
||||
"""
|
||||
processed_img_keys = []
|
||||
for key in img_keys:
|
||||
key = key.replace(OBS_IMAGES + ".", "")
|
||||
if key in CAMERA_NAME_MAPPING:
|
||||
key = CAMERA_NAME_MAPPING[key]
|
||||
else:
|
||||
if 'view' in key:
|
||||
key = key.replace('_', ' ')
|
||||
else:
|
||||
key = key + " view"
|
||||
processed_img_keys.append(key)
|
||||
return processed_img_keys
|
||||
|
||||
def get_action_tokens(
|
||||
normalized_actions: Union[torch.Tensor, List], action_tokenizer
|
||||
@@ -599,7 +618,6 @@ def replace_action_token(
|
||||
text: List[str],
|
||||
norm_action: Optional[torch.Tensor],
|
||||
action_tokenizer,
|
||||
dataset_names: List[str],
|
||||
dof_masks: Optional[torch.Tensor] = None,
|
||||
) -> List[str]:
|
||||
"""Replace action placeholders in text with actual action tokens.
|
||||
@@ -615,17 +633,10 @@ def replace_action_token(
|
||||
List of text strings with action tokens replaced
|
||||
"""
|
||||
# Filter out multimodal dataset names
|
||||
dataset_names = [
|
||||
name for name in dataset_names if name not in MULTIMODAL_DATASET_NAMES
|
||||
]
|
||||
|
||||
# Get required action chunk sizes
|
||||
required_chunk_sizes = [32 for name in dataset_names]
|
||||
|
||||
if action_tokenizer is not None and norm_action is not None:
|
||||
# Extract actions based on chunk sizes and DOF masks
|
||||
norm_action = [
|
||||
action[: required_chunk_sizes[i], dof_masks[i, 0].bool()]
|
||||
action[: 32, dof_masks[i, 0].bool()]
|
||||
for i, action in enumerate(norm_action)
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
# Wall-X policy tests
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local Wall-X installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.wall_x import ( # noqa: E402
|
||||
WallXConfig,
|
||||
WallXPolicy,
|
||||
make_wall_x_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
set_seed(42)
|
||||
config = WallXConfig(device='cuda')
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(7,),
|
||||
),
|
||||
"observation.images.face_view": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224),
|
||||
),
|
||||
}
|
||||
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(7,),
|
||||
),
|
||||
}
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
"observation.images.face_view": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = WallXPolicy(config)
|
||||
preprocessor, postprocessor = make_wall_x_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
# Test forward pass with dummy data
|
||||
batch_size = 1
|
||||
device = config.device
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.face_view": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
), # Use rand for [0,1] range
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
try:
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||
except Exception as e:
|
||||
print(f"Forward pass failed: {e}")
|
||||
raise
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
action = postprocessor(action)
|
||||
print(f"Action: {action}")
|
||||
print(f"Action prediction successful. Action shape: {action.shape}")
|
||||
except Exception as e:
|
||||
print(f"Action prediction failed: {e}")
|
||||
raise
|
||||
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
try:
|
||||
config = make_policy_config(
|
||||
policy_type="wall_x",
|
||||
)
|
||||
print("Config created successfully through factory")
|
||||
print(f" Config type: {type(config).__name__}")
|
||||
except Exception as e:
|
||||
print(f"Config creation failed: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_policy_instantiation()
|
||||
test_config_creation()
|
||||
Reference in New Issue
Block a user