mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
reduce to least config and params & pass lerobot basic test
This commit is contained in:
committed by
Michel Aractingi
parent
feebca050a
commit
d10d3ef251
@@ -13,13 +13,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from lerobot.optim.optimizers import AdamWConfig
|
from lerobot.optim.optimizers import AdamWConfig
|
||||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||||
from lerobot.utils.constants import OBS_IMAGES
|
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("wall_x")
|
@PreTrainedConfig.register_subclass("wall_x")
|
||||||
@@ -34,48 +32,15 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
This config supports multi-modal learning with vision, language, and action data.
|
This config supports multi-modal learning with vision, language, and action data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ==================== Model and Paths Configuration ====================
|
|
||||||
# Logging
|
|
||||||
log_name: str = "wall_x_training"
|
|
||||||
log_project: str = "vla_training"
|
|
||||||
model_type: str = "wall-oss"
|
|
||||||
|
|
||||||
# Pretrained model paths
|
|
||||||
pretrained_wallx_path: str | None = None # Path to pretrained Wall-X model
|
|
||||||
save_path: str | None = None # Path to save checkpoints
|
|
||||||
processor_path: str | None = None # Path to processor (defaults to pretrained_wallx_path)
|
|
||||||
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
|
|
||||||
|
|
||||||
# Tokenizer settings
|
|
||||||
use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
|
|
||||||
|
|
||||||
# ==================== Profiling Configuration ====================
|
|
||||||
profile: bool = False
|
|
||||||
profile_save_path: str | None = None
|
|
||||||
profile_wait_iters: int = 10
|
|
||||||
profile_warmup_iters: int = 5
|
|
||||||
profile_active_iters: int = 2
|
|
||||||
|
|
||||||
# ==================== Training Hyperparameters ====================
|
|
||||||
num_warmup_steps: int = 100
|
|
||||||
num_training_steps: int = 64000000
|
|
||||||
learning_rate: float = 5e-5
|
|
||||||
min_lr: float = 5e-5
|
|
||||||
num_epoch: int = 100
|
|
||||||
gradient_accumulation_steps: int = 32
|
|
||||||
batch_size_per_gpu: int = 8
|
|
||||||
padding_side: str = "left"
|
|
||||||
epoch_save_interval: int = 10
|
|
||||||
|
|
||||||
# Training optimization
|
|
||||||
fsdp2: bool = False
|
|
||||||
torch_compile: bool = False
|
|
||||||
|
|
||||||
# ==================== Input / Output Structure ====================
|
# ==================== Input / Output Structure ====================
|
||||||
n_obs_steps: int = 1
|
n_obs_steps: int = 1
|
||||||
chunk_size: int = 32 # action_horizon in wall-x
|
chunk_size: int = 32 # action_horizon in wall-x
|
||||||
n_action_steps: int = 32
|
n_action_steps: int = 32
|
||||||
|
|
||||||
|
# Action dimension - wall-x uses 20
|
||||||
|
max_action_dim: int = 20
|
||||||
|
max_state_dim: int = 20 # For proprioception
|
||||||
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"VISUAL": NormalizationMode.IDENTITY,
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
@@ -84,101 +49,17 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Action dimension - wall-x uses hardcoded 20
|
# ==================== Action Prediction ====================
|
||||||
max_action_dim: int = 20
|
# Pretrained model paths
|
||||||
max_state_dim: int = 20 # For proprioception
|
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||||
|
|
||||||
# Image preprocessing
|
# Action prediction mode: "diffusion" or "fast"
|
||||||
resize_imgs_with_padding: tuple[int, int] | None = None # wall-x uses Qwen processor
|
prediction_mode: str = "diffusion"
|
||||||
|
|
||||||
# Tokenizer
|
# Tokenizer settings
|
||||||
tokenizer_max_length: int = 256
|
use_fast_tokenizer: bool = False # True: train FAST, False: train Flow
|
||||||
|
action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode)
|
||||||
|
|
||||||
# ==================== Model Architecture ====================
|
|
||||||
vlm_model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
|
||||||
load_vlm_weights: bool = True
|
|
||||||
|
|
||||||
# Vision config
|
|
||||||
vision_config: dict = field(default_factory=lambda: {
|
|
||||||
"depth": 32,
|
|
||||||
"hidden_size": 3584,
|
|
||||||
"hidden_act": "silu",
|
|
||||||
"intermediate_size": 3420,
|
|
||||||
"num_heads": 16,
|
|
||||||
"patch_size": 14,
|
|
||||||
"spatial_merge_size": 2,
|
|
||||||
"temporal_patch_size": 2,
|
|
||||||
"window_size": 112,
|
|
||||||
"out_hidden_size": 3584,
|
|
||||||
})
|
|
||||||
|
|
||||||
# Language model config
|
|
||||||
hidden_size: int = 3584 # 8192 for 7B model
|
|
||||||
intermediate_size: int = 18944 # 29568 for 7B model
|
|
||||||
num_hidden_layers: int = 36 # 80 for 7B model
|
|
||||||
num_attention_heads: int = 28 # 64 for 7B model
|
|
||||||
num_key_value_heads: int = 4 # 8 for 7B model
|
|
||||||
vocab_size: int = 152064
|
|
||||||
|
|
||||||
# ==================== Action Prediction ====================
|
|
||||||
# Action prediction mode: "flow" or "fast"
|
|
||||||
prediction_mode: str = "flow"
|
|
||||||
|
|
||||||
# Flow matching parameters
|
|
||||||
noise_scheduler: dict = field(default_factory=lambda: {
|
|
||||||
"beta_alpha": 1.5, # Beta distribution concentration1
|
|
||||||
"beta_beta": 1.0, # Beta distribution concentration0
|
|
||||||
"s": 0.999, # Scaling factor for time
|
|
||||||
})
|
|
||||||
|
|
||||||
# Decoding parameters
|
|
||||||
num_inference_timesteps: int = 10 # Number of ODE solver steps
|
|
||||||
ode_solver_method: str = "euler" # ODE solver method
|
|
||||||
|
|
||||||
# ==================== Robot Configuration ====================
|
|
||||||
# Degrees of freedom configuration - defines action space
|
|
||||||
dof_config: dict = field(default_factory=lambda: {
|
|
||||||
"left_ee_pos": 3,
|
|
||||||
"left_ee_rot": 3,
|
|
||||||
"left_gripper": 1,
|
|
||||||
"right_ee_pos": 3,
|
|
||||||
"right_ee_rot": 3,
|
|
||||||
"right_gripper": 1,
|
|
||||||
})
|
|
||||||
|
|
||||||
# Proprioception configuration (typically mirrors dof_config)
|
|
||||||
agent_pos_config: dict = field(default_factory=lambda: {
|
|
||||||
"left_ee_pos": 3,
|
|
||||||
"left_ee_rot": 3,
|
|
||||||
"left_gripper": 1,
|
|
||||||
"right_ee_pos": 3,
|
|
||||||
"right_ee_rot": 3,
|
|
||||||
"right_gripper": 1,
|
|
||||||
})
|
|
||||||
|
|
||||||
# Customized robot configuration
|
|
||||||
enable_customized_robot_config: bool = False
|
|
||||||
customized_robot_config: dict = field(default_factory=lambda: {
|
|
||||||
"name": "",
|
|
||||||
"customized_dof_config": {},
|
|
||||||
"customized_agent_pos_config": {},
|
|
||||||
})
|
|
||||||
|
|
||||||
# Normalization statistics path
|
|
||||||
norm_stats_path: str | None = None
|
|
||||||
|
|
||||||
# ==================== MoE Configuration ====================
|
|
||||||
num_experts: int = 4
|
|
||||||
attention_moe: bool = False
|
|
||||||
mlp_moe: bool = False
|
|
||||||
|
|
||||||
# ==================== Finetuning Settings ====================
|
|
||||||
freeze_vision_encoder: bool = True
|
|
||||||
train_expert_only: bool = False # wall-x trains more components
|
|
||||||
train_action_head: bool = True
|
|
||||||
|
|
||||||
# Cache
|
|
||||||
use_cache: bool = True
|
|
||||||
|
|
||||||
# ==================== Optimizer Presets ====================
|
# ==================== Optimizer Presets ====================
|
||||||
optimizer_lr: float = 2e-5
|
optimizer_lr: float = 2e-5
|
||||||
@@ -191,44 +72,6 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
scheduler_decay_steps: int = 100000
|
scheduler_decay_steps: int = 100000
|
||||||
scheduler_decay_lr: float = 1e-6
|
scheduler_decay_lr: float = 1e-6
|
||||||
|
|
||||||
# ==================== Dataset Configuration ====================
|
|
||||||
# Dataset-specific normalization statistics
|
|
||||||
action_statistics: dict = field(default_factory=dict)
|
|
||||||
|
|
||||||
# Data configuration
|
|
||||||
data_config: dict = field(default_factory=lambda: {
|
|
||||||
"use_lerobot": True,
|
|
||||||
"lerobot_config": {
|
|
||||||
"repo_id": "",
|
|
||||||
"root": None,
|
|
||||||
"episodes": None,
|
|
||||||
"image_transforms": None,
|
|
||||||
"delta_timestamps": None,
|
|
||||||
"tolerance_s": 1e-4,
|
|
||||||
"revision": None,
|
|
||||||
"force_cache_sync": False,
|
|
||||||
"download_videos": True,
|
|
||||||
"video_backend": None,
|
|
||||||
},
|
|
||||||
"action_horizon": 32,
|
|
||||||
"train_test_split": 0.95,
|
|
||||||
"obs_action_keys": [],
|
|
||||||
"predict_action_keys": [],
|
|
||||||
"resolution": {
|
|
||||||
"face_view": 256,
|
|
||||||
"left_wrist_view": 256,
|
|
||||||
"right_wrist_view": 256,
|
|
||||||
"move1_view": 256,
|
|
||||||
"move2_view": 256,
|
|
||||||
"top_view": 256,
|
|
||||||
"wall_view": 256,
|
|
||||||
"multi_modal": 256,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
# ==================== Resume Configuration ====================
|
|
||||||
resume_config: dict | None = field(default_factory=lambda: None)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
@@ -239,243 +82,55 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.prediction_mode not in ["flow", "fast"]:
|
if self.prediction_mode not in ["diffusion", "fast"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"prediction_mode must be 'flow' or 'fast', got {self.prediction_mode}"
|
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
||||||
)
|
|
||||||
|
|
||||||
# Validate dof_config total doesn't exceed max_action_dim
|
|
||||||
total_dof = sum(self.dof_config.values())
|
|
||||||
if total_dof > self.max_action_dim:
|
|
||||||
raise ValueError(
|
|
||||||
f"Total DOF ({total_dof}) exceeds max_action_dim ({self.max_action_dim})"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sync prediction_mode with use_fast_tokenizer
|
# Sync prediction_mode with use_fast_tokenizer
|
||||||
if self.use_fast_tokenizer:
|
if self.use_fast_tokenizer:
|
||||||
self.prediction_mode = "fast"
|
self.prediction_mode = "fast"
|
||||||
else:
|
else:
|
||||||
self.prediction_mode = "flow"
|
self.prediction_mode = "diffusion"
|
||||||
|
|
||||||
def get_train_config(self) -> dict:
|
def validate_features(self) -> None:
|
||||||
"""
|
"""Validate and set up input/output features."""
|
||||||
Extract the complete train_config dictionary matching the YAML training configuration format.
|
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||||
|
if not image_features:
|
||||||
|
raise ValueError(
|
||||||
|
"Wall-X policy requires at least one visual input feature. "
|
||||||
|
"No features of type FeatureType.VISUAL found in input_features."
|
||||||
|
)
|
||||||
|
|
||||||
This method constructs the full train_config from WallXConfig fields, suitable for
|
if "observation.state" not in self.input_features:
|
||||||
training scripts and Qwen2_5_VLMoEForAction.from_pretrained.
|
state_feature = PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
Returns:
|
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||||
dict: Complete training configuration matching YAML structure.
|
)
|
||||||
"""
|
self.input_features["observation.state"] = state_feature
|
||||||
# Build customized_robot_config
|
|
||||||
if self.enable_customized_robot_config and self.customized_robot_config:
|
|
||||||
customized_robot_config = {
|
|
||||||
"name": self.customized_robot_config.get("name", ""),
|
|
||||||
"customized_dof_config": self.customized_robot_config.get(
|
|
||||||
"customized_dof_config", self.dof_config
|
|
||||||
),
|
|
||||||
"customized_agent_pos_config": self.customized_robot_config.get(
|
|
||||||
"customized_agent_pos_config", self.agent_pos_config
|
|
||||||
),
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
customized_robot_config = {
|
state_shape = self.input_features["observation.state"].shape
|
||||||
"name": self.data_config.get("lerobot_config", {}).get("repo_id", ""),
|
state_dim = state_shape[0] if state_shape else 0
|
||||||
"customized_dof_config": self.dof_config,
|
if state_dim > self.max_state_dim:
|
||||||
"customized_agent_pos_config": self.agent_pos_config,
|
raise ValueError(
|
||||||
}
|
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
|
||||||
|
f"Either reduce state dimension or increase max_state_dim in config."
|
||||||
|
)
|
||||||
|
|
||||||
train_config = {
|
if "action" not in self.output_features:
|
||||||
# Model and paths configuration
|
action_feature = PolicyFeature(
|
||||||
"log_name": self.log_name,
|
type=FeatureType.ACTION,
|
||||||
"log_project": self.log_project,
|
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||||
"model_type": self.model_type,
|
)
|
||||||
"pretrained_wallx_path": self.pretrained_wallx_path,
|
self.output_features["action"] = action_feature
|
||||||
"save_path": self.save_path,
|
else:
|
||||||
"use_fast_tokenizer": self.use_fast_tokenizer,
|
action_shape = self.output_features["action"].shape
|
||||||
"action_tokenizer_path": self.action_tokenizer_path,
|
action_dim = action_shape[0] if action_shape else 0
|
||||||
|
if action_dim > self.max_action_dim:
|
||||||
# Profiling configuration
|
raise ValueError(
|
||||||
"profile": self.profile,
|
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
|
||||||
"profile_save_path": self.profile_save_path,
|
f"Either reduce action dimension or increase max_action_dim in config."
|
||||||
"profile_wait_iters": self.profile_wait_iters,
|
)
|
||||||
"profile_warmup_iters": self.profile_warmup_iters,
|
|
||||||
"profile_active_iters": self.profile_active_iters,
|
|
||||||
|
|
||||||
# Training hyperparameters
|
|
||||||
"num_warmup_steps": self.num_warmup_steps,
|
|
||||||
"num_training_steps": self.num_training_steps,
|
|
||||||
"learning_rate": self.learning_rate,
|
|
||||||
"min_lr": self.min_lr,
|
|
||||||
"num_epoch": self.num_epoch,
|
|
||||||
"gradient_accumulation_steps": self.gradient_accumulation_steps,
|
|
||||||
"batch_size_per_gpu": self.batch_size_per_gpu,
|
|
||||||
"padding_side": self.padding_side,
|
|
||||||
"epoch_save_interval": self.epoch_save_interval,
|
|
||||||
|
|
||||||
# Training optimization
|
|
||||||
"FSDP2": self.fsdp2,
|
|
||||||
"torch_compile": self.torch_compile,
|
|
||||||
|
|
||||||
# Robot configuration
|
|
||||||
"dof_config": self.dof_config,
|
|
||||||
"agent_pos_config": self.agent_pos_config,
|
|
||||||
|
|
||||||
# Normalization stats
|
|
||||||
"norm_stats_path": self.norm_stats_path,
|
|
||||||
|
|
||||||
# Customized robot config
|
|
||||||
"enable_customized_robot_config": self.enable_customized_robot_config,
|
|
||||||
"customized_robot_config": customized_robot_config,
|
|
||||||
|
|
||||||
# Resume configuration
|
|
||||||
"resume": self.resume_config,
|
|
||||||
|
|
||||||
# Data configuration
|
|
||||||
"data": self.data_config,
|
|
||||||
}
|
|
||||||
|
|
||||||
return train_config
|
|
||||||
|
|
||||||
def get_dataload_config(self) -> dict:
|
|
||||||
"""
|
|
||||||
Extract data loading configuration from config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Data loading configuration for preprocessing.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"action_horizon": self.data_config.get("action_horizon", self.chunk_size),
|
|
||||||
"train_test_split": self.data_config.get("train_test_split", 0.95),
|
|
||||||
"split_seed": 42,
|
|
||||||
"predict_action_keys": self.data_config.get("predict_action_keys", []),
|
|
||||||
"obs_action_keys": self.data_config.get("obs_action_keys", []),
|
|
||||||
"resolution": self.data_config.get("resolution", {}),
|
|
||||||
"priority_order": None,
|
|
||||||
"max_length": self.tokenizer_max_length,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_lerobot_config(self) -> dict:
|
|
||||||
"""
|
|
||||||
Extract LeRobot dataset configuration.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: LeRobot dataset configuration.
|
|
||||||
"""
|
|
||||||
return self.data_config.get("lerobot_config", {})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_yaml_dict(cls, yaml_dict: dict) -> "WallXConfig":
|
|
||||||
"""
|
|
||||||
Create a WallXConfig from a YAML configuration dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
yaml_dict: Dictionary loaded from YAML training config file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
WallXConfig instance with values from YAML.
|
|
||||||
"""
|
|
||||||
config_kwargs = {}
|
|
||||||
|
|
||||||
# Model and paths
|
|
||||||
if "log_name" in yaml_dict:
|
|
||||||
config_kwargs["log_name"] = yaml_dict["log_name"]
|
|
||||||
if "log_project" in yaml_dict:
|
|
||||||
config_kwargs["log_project"] = yaml_dict["log_project"]
|
|
||||||
if "model_type" in yaml_dict:
|
|
||||||
config_kwargs["model_type"] = yaml_dict["model_type"]
|
|
||||||
if "pretrained_wallx_path" in yaml_dict:
|
|
||||||
config_kwargs["pretrained_wallx_path"] = yaml_dict["pretrained_wallx_path"]
|
|
||||||
if "save_path" in yaml_dict:
|
|
||||||
config_kwargs["save_path"] = yaml_dict["save_path"]
|
|
||||||
if "use_fast_tokenizer" in yaml_dict:
|
|
||||||
config_kwargs["use_fast_tokenizer"] = yaml_dict["use_fast_tokenizer"]
|
|
||||||
if "action_tokenizer_path" in yaml_dict:
|
|
||||||
config_kwargs["action_tokenizer_path"] = yaml_dict["action_tokenizer_path"]
|
|
||||||
|
|
||||||
# Profiling
|
|
||||||
if "profile" in yaml_dict:
|
|
||||||
config_kwargs["profile"] = yaml_dict["profile"]
|
|
||||||
if "profile_save_path" in yaml_dict:
|
|
||||||
config_kwargs["profile_save_path"] = yaml_dict["profile_save_path"]
|
|
||||||
if "profile_wait_iters" in yaml_dict:
|
|
||||||
config_kwargs["profile_wait_iters"] = yaml_dict["profile_wait_iters"]
|
|
||||||
if "profile_warmup_iters" in yaml_dict:
|
|
||||||
config_kwargs["profile_warmup_iters"] = yaml_dict["profile_warmup_iters"]
|
|
||||||
if "profile_active_iters" in yaml_dict:
|
|
||||||
config_kwargs["profile_active_iters"] = yaml_dict["profile_active_iters"]
|
|
||||||
|
|
||||||
# Training hyperparameters
|
|
||||||
if "num_warmup_steps" in yaml_dict:
|
|
||||||
config_kwargs["num_warmup_steps"] = yaml_dict["num_warmup_steps"]
|
|
||||||
config_kwargs["scheduler_warmup_steps"] = yaml_dict["num_warmup_steps"]
|
|
||||||
if "num_training_steps" in yaml_dict:
|
|
||||||
config_kwargs["num_training_steps"] = yaml_dict["num_training_steps"]
|
|
||||||
config_kwargs["scheduler_decay_steps"] = yaml_dict["num_training_steps"]
|
|
||||||
if "learning_rate" in yaml_dict:
|
|
||||||
config_kwargs["learning_rate"] = yaml_dict["learning_rate"]
|
|
||||||
config_kwargs["optimizer_lr"] = yaml_dict["learning_rate"]
|
|
||||||
if "min_lr" in yaml_dict:
|
|
||||||
config_kwargs["min_lr"] = yaml_dict["min_lr"]
|
|
||||||
config_kwargs["scheduler_decay_lr"] = yaml_dict["min_lr"]
|
|
||||||
if "num_epoch" in yaml_dict:
|
|
||||||
config_kwargs["num_epoch"] = yaml_dict["num_epoch"]
|
|
||||||
if "gradient_accumulation_steps" in yaml_dict:
|
|
||||||
config_kwargs["gradient_accumulation_steps"] = yaml_dict["gradient_accumulation_steps"]
|
|
||||||
if "batch_size_per_gpu" in yaml_dict:
|
|
||||||
config_kwargs["batch_size_per_gpu"] = yaml_dict["batch_size_per_gpu"]
|
|
||||||
if "padding_side" in yaml_dict:
|
|
||||||
config_kwargs["padding_side"] = yaml_dict["padding_side"]
|
|
||||||
if "epoch_save_interval" in yaml_dict:
|
|
||||||
config_kwargs["epoch_save_interval"] = yaml_dict["epoch_save_interval"]
|
|
||||||
|
|
||||||
# Training optimization
|
|
||||||
if "FSDP2" in yaml_dict:
|
|
||||||
config_kwargs["fsdp2"] = yaml_dict["FSDP2"]
|
|
||||||
if "torch_compile" in yaml_dict:
|
|
||||||
config_kwargs["torch_compile"] = yaml_dict["torch_compile"]
|
|
||||||
|
|
||||||
# Robot configuration
|
|
||||||
if "dof_config" in yaml_dict:
|
|
||||||
config_kwargs["dof_config"] = yaml_dict["dof_config"]
|
|
||||||
if "agent_pos_config" in yaml_dict:
|
|
||||||
config_kwargs["agent_pos_config"] = yaml_dict["agent_pos_config"]
|
|
||||||
|
|
||||||
# Normalization stats
|
|
||||||
if "norm_stats_path" in yaml_dict:
|
|
||||||
config_kwargs["norm_stats_path"] = yaml_dict["norm_stats_path"]
|
|
||||||
|
|
||||||
# Customized robot config
|
|
||||||
if "enable_customized_robot_config" in yaml_dict:
|
|
||||||
config_kwargs["enable_customized_robot_config"] = yaml_dict["enable_customized_robot_config"]
|
|
||||||
if "customized_robot_config" in yaml_dict:
|
|
||||||
config_kwargs["customized_robot_config"] = yaml_dict["customized_robot_config"]
|
|
||||||
|
|
||||||
# Resume config
|
|
||||||
if "resume" in yaml_dict:
|
|
||||||
config_kwargs["resume_config"] = yaml_dict["resume"]
|
|
||||||
|
|
||||||
# Data configuration
|
|
||||||
if "data" in yaml_dict:
|
|
||||||
data = yaml_dict["data"]
|
|
||||||
data_config = {
|
|
||||||
"use_lerobot": data.get("use_lerobot", True),
|
|
||||||
"action_horizon": data.get("action_horizon", 32),
|
|
||||||
"train_test_split": data.get("train_test_split", 0.95),
|
|
||||||
"obs_action_keys": data.get("obs_action_keys", []),
|
|
||||||
"predict_action_keys": data.get("predict_action_keys", []),
|
|
||||||
"resolution": data.get("resolution", {}),
|
|
||||||
}
|
|
||||||
if "lerobot_config" in data:
|
|
||||||
data_config["lerobot_config"] = data["lerobot_config"]
|
|
||||||
config_kwargs["data_config"] = data_config
|
|
||||||
|
|
||||||
# Set chunk_size from action_horizon
|
|
||||||
if "action_horizon" in data:
|
|
||||||
config_kwargs["chunk_size"] = data["action_horizon"]
|
|
||||||
config_kwargs["n_action_steps"] = data["action_horizon"]
|
|
||||||
|
|
||||||
return cls(**config_kwargs)
|
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> AdamWConfig:
|
def get_optimizer_preset(self) -> AdamWConfig:
|
||||||
return AdamWConfig(
|
return AdamWConfig(
|
||||||
@@ -496,7 +151,7 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_delta_indices(self) -> list:
|
def observation_delta_indices(self) -> list:
|
||||||
return [0]
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list:
|
def action_delta_indices(self) -> list:
|
||||||
|
|||||||
@@ -16,15 +16,9 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Wall-X Constants and Configuration Data.
|
Wall-X Constants and Configuration Data.
|
||||||
|
|
||||||
Contains dataset names, key mappings, frequency mappings, and action statistics
|
|
||||||
for cross-embodiment robotic control.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from lerobot.utils.constants import OBS_STATE, OBS_IMAGES, ACTION
|
||||||
|
|
||||||
# Add wall-x repo to path if available
|
|
||||||
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
|
|
||||||
|
|
||||||
CAMERA_NAME_MAPPING = {
|
CAMERA_NAME_MAPPING = {
|
||||||
"face_view": "front view",
|
"face_view": "front view",
|
||||||
@@ -35,3 +29,15 @@ CAMERA_NAME_MAPPING = {
|
|||||||
"wall_view": "wall view",
|
"wall_view": "wall view",
|
||||||
"top_view": "top view",
|
"top_view": "top view",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RESOLUTION = 256
|
||||||
|
|
||||||
|
# Parameters for preprocessing
|
||||||
|
MAX_PIXELS = 16384 * 28 * 28
|
||||||
|
MIN_PIXELS = 4 * 28 * 28
|
||||||
|
IMAGE_FACTOR = 28
|
||||||
|
PRIORITY_ORDER = None
|
||||||
|
GENERATE_SUBTASK_RATIO = 0.0
|
||||||
|
MODEL_TYPE = "qwen2_5"
|
||||||
|
|
||||||
|
TOKENIZER_MAX_LENGTH = 768
|
||||||
@@ -38,6 +38,7 @@ import builtins
|
|||||||
import glob
|
import glob
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
from os import PathLike
|
||||||
import sys
|
import sys
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -57,16 +58,8 @@ from torchdiffeq import odeint
|
|||||||
from transformers import AutoConfig, AutoProcessor
|
from transformers import AutoConfig, AutoProcessor
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.cache_utils import (
|
from transformers.cache_utils import (
|
||||||
Cache,
|
|
||||||
DynamicCache,
|
|
||||||
SlidingWindowCache,
|
|
||||||
StaticCache,
|
StaticCache,
|
||||||
)
|
)
|
||||||
from transformers.generation import GenerationMixin
|
|
||||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
|
||||||
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from transformers.utils import is_torchdynamo_compiling, logging
|
from transformers.utils import is_torchdynamo_compiling, logging
|
||||||
from transformers import AutoProcessor, BatchFeature
|
from transformers import AutoProcessor, BatchFeature
|
||||||
from qwen_vl_utils.vision_process import smart_resize
|
from qwen_vl_utils.vision_process import smart_resize
|
||||||
@@ -80,29 +73,18 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
|
|||||||
from lerobot.policies.wall_x.utils import *
|
from lerobot.policies.wall_x.utils import *
|
||||||
from lerobot.policies.wall_x.constant import *
|
from lerobot.policies.wall_x.constant import *
|
||||||
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
||||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
|
||||||
Qwen2RMSNorm,
|
|
||||||
)
|
|
||||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||||
Qwen2_5_VLRotaryEmbedding,
|
|
||||||
Qwen2_5_VLPreTrainedModel,
|
|
||||||
Qwen2_5_VLForConditionalGeneration,
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
|
from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
|
||||||
Qwen2_5_VisionTransformerPretrainedModel,
|
Qwen2_5_VisionTransformerPretrainedModel,
|
||||||
Qwen2_5_VLDecoderLayer_with_MoE,
|
|
||||||
Qwen2_5_VLACausalLMOutputWithPast,
|
Qwen2_5_VLACausalLMOutputWithPast,
|
||||||
Qwen2_5_VLMoEModel,
|
Qwen2_5_VLMoEModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
# Add wall-x repo to path if available
|
|
||||||
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
|
|
||||||
if WALL_X_PATH.exists():
|
|
||||||
sys.path.insert(0, str(WALL_X_PATH))
|
|
||||||
|
|
||||||
|
|
||||||
class SinusoidalPosEmb(nn.Module):
|
class SinusoidalPosEmb(nn.Module):
|
||||||
"""Sinusoidal positional embedding for diffusion timesteps."""
|
"""Sinusoidal positional embedding for diffusion timesteps."""
|
||||||
@@ -129,7 +111,7 @@ class ActionHead(nn.Module):
|
|||||||
for action sequence prediction.
|
for action sequence prediction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: WallXConfig):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -138,10 +120,9 @@ class ActionHead(nn.Module):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
# Beta distribution for noise scheduling
|
# Beta distribution for noise scheduling
|
||||||
noise_config = config.noise_scheduler
|
self.beta_alpha = 1.5
|
||||||
self.beta_alpha = noise_config.get("beta_alpha", 1.5)
|
self.beta_beta = 1.0
|
||||||
self.beta_beta = noise_config.get("beta_beta", 1.0)
|
self.s = 0.999
|
||||||
self.s = noise_config.get("s", 0.999)
|
|
||||||
|
|
||||||
# Sinusoidal timestep embedding
|
# Sinusoidal timestep embedding
|
||||||
self.time_embed = SinusoidalPosEmb(config.hidden_size)
|
self.time_embed = SinusoidalPosEmb(config.hidden_size)
|
||||||
@@ -277,12 +258,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
pretrained_model_path,
|
pretrained_name_or_path,
|
||||||
train_config,
|
config=None,
|
||||||
config_path=None,
|
|
||||||
processor_path=None,
|
|
||||||
action_tokenizer_path=None,
|
action_tokenizer_path=None,
|
||||||
**kwargs,
|
cache_dir: str | PathLike | None = None,
|
||||||
|
force_download: bool = False,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
token: str | bool | None = None,
|
||||||
|
revision: str = "main",
|
||||||
|
strict: bool = False,
|
||||||
|
**kwargs: Any
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Load model from pretrained model path.
|
Load model from pretrained model path.
|
||||||
@@ -290,17 +275,24 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
Args:
|
Args:
|
||||||
pretrained_model_path (str): Model directory path containing model.safetensors file
|
pretrained_model_path (str): Model directory path containing model.safetensors file
|
||||||
config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
|
config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
|
||||||
processor_path (str, optional): Processor path, if None will load from default config
|
|
||||||
action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
|
action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
|
||||||
**kwargs: Additional arguments
|
**kwargs: Additional arguments
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Qwen2_5_VLMoEForAction: Loaded model instance
|
Qwen2_5_VLMoEForAction: Loaded model instance
|
||||||
"""
|
"""
|
||||||
# Load model components from pretrained path
|
if config is None:
|
||||||
config_path = os.path.join(pretrained_model_path, "config.json")
|
config = cls.config_class.from_pretrained(
|
||||||
config = cls.config_class.from_pretrained(config_path)
|
pretrained_name_or_path,
|
||||||
processor = AutoProcessor.from_pretrained(pretrained_model_path, use_fast=True)
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
strict=strict,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
|
||||||
if action_tokenizer_path is not None:
|
if action_tokenizer_path is not None:
|
||||||
processor.action_processor = AutoProcessor.from_pretrained(
|
processor.action_processor = AutoProcessor.from_pretrained(
|
||||||
action_tokenizer_path, trust_remote_code=True
|
action_tokenizer_path, trust_remote_code=True
|
||||||
@@ -312,22 +304,41 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
# Resize token embeddings to match processor tokenizer vocabulary size
|
# Resize token embeddings to match processor tokenizer vocabulary size
|
||||||
model.resize_token_embeddings(len(processor.tokenizer))
|
model.resize_token_embeddings(len(processor.tokenizer))
|
||||||
|
|
||||||
# Load model state dict from safetensors file
|
# Try to load the model.safetensors file
|
||||||
safetensor_files = glob.glob(
|
print(f"Loading model from: {pretrained_name_or_path}")
|
||||||
os.path.join(pretrained_model_path, "*.safetensors")
|
try:
|
||||||
)
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
|
# Try safetensors first
|
||||||
|
resolved_file = cached_file(
|
||||||
|
pretrained_name_or_path,
|
||||||
|
"model.safetensors",
|
||||||
|
cache_dir=kwargs.get("cache_dir"),
|
||||||
|
force_download=kwargs.get("force_download", False),
|
||||||
|
resume_download=kwargs.get("resume_download"),
|
||||||
|
proxies=kwargs.get("proxies"),
|
||||||
|
use_auth_token=kwargs.get("use_auth_token"),
|
||||||
|
revision=kwargs.get("revision"),
|
||||||
|
local_files_only=kwargs.get("local_files_only", False),
|
||||||
|
)
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
sd = load_file(resolved_file)
|
||||||
|
print("✓ Loaded state dict from model.safetensors")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Could not load state dict from remote files: {e}")
|
||||||
|
print("Returning model without loading pretrained weights")
|
||||||
|
return model
|
||||||
|
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
for file in safetensor_files:
|
# filter normalizer statistic params
|
||||||
sd = load_file(file, device="cpu")
|
del_keys = []
|
||||||
# filter normalizer statistic params
|
for key in sd.keys():
|
||||||
del_keys = []
|
if "action_preprocessor.normalizer" in key:
|
||||||
for key in sd.keys():
|
del_keys.append(key)
|
||||||
if "action_preprocessor.normalizer" in key:
|
for key in del_keys:
|
||||||
print(f"filter load model weight {key}")
|
del sd[key]
|
||||||
del_keys.append(key)
|
state_dict.update(sd)
|
||||||
for key in del_keys:
|
|
||||||
del sd[key]
|
|
||||||
state_dict.update(sd)
|
|
||||||
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
@@ -730,7 +741,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
rope_deltas: Optional[torch.LongTensor] = None,
|
rope_deltas: Optional[torch.LongTensor] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||||
dataset_names: Optional[str] = None,
|
|
||||||
dof_mask: Optional[torch.FloatTensor] = None,
|
dof_mask: Optional[torch.FloatTensor] = None,
|
||||||
agent_pos_mask: Optional[torch.FloatTensor] = None,
|
agent_pos_mask: Optional[torch.FloatTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -763,7 +773,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
rope_deltas (torch.LongTensor, optional): RoPE position deltas
|
rope_deltas (torch.LongTensor, optional): RoPE position deltas
|
||||||
cache_position (torch.LongTensor, optional): Cache position indices
|
cache_position (torch.LongTensor, optional): Cache position indices
|
||||||
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
|
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
|
||||||
dataset_names (str, optional): Names of datasets in the current batch
|
|
||||||
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens
|
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens
|
||||||
agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data
|
agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data
|
||||||
**kwargs: Additional keyword arguments
|
**kwargs: Additional keyword arguments
|
||||||
@@ -872,7 +881,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
)
|
)
|
||||||
proprioception = self.action_preprocessor.proprioception_proj(
|
proprioception = self.action_preprocessor.proprioception_proj(
|
||||||
proprioception,
|
proprioception,
|
||||||
dataset_names,
|
|
||||||
agent_pos_mask,
|
agent_pos_mask,
|
||||||
use_history=proprioception.shape[1] > 1,
|
use_history=proprioception.shape[1] > 1,
|
||||||
)
|
)
|
||||||
@@ -909,7 +917,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
)
|
)
|
||||||
dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
||||||
noisy_action_emb, flow = self.action_preprocessor(
|
noisy_action_emb, flow = self.action_preprocessor(
|
||||||
action_chunk, dataset_names, dof_mask
|
action_chunk, dof_mask
|
||||||
)
|
)
|
||||||
mask = input_ids == self.action_token_id_set["action_token_id"]
|
mask = input_ids == self.action_token_id_set["action_token_id"]
|
||||||
mask_unsqueezed = mask.unsqueeze(-1)
|
mask_unsqueezed = mask.unsqueeze(-1)
|
||||||
@@ -953,18 +961,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
# Compute losses if labels are provided
|
# Compute losses if labels are provided
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss = 0
|
loss = 0
|
||||||
action_accuracy = 0
|
|
||||||
unique_datasets_name = list(set(dataset_names))
|
|
||||||
|
|
||||||
# Initialize per-dataset loss tracking dictionaries
|
|
||||||
channel_loss_dict = {
|
|
||||||
dataset_name: torch.tensor(0.0, device=logits.device)
|
|
||||||
for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES
|
|
||||||
}
|
|
||||||
channel_loss_count_dict = {
|
|
||||||
dataset_name: torch.tensor(0, device=logits.device)
|
|
||||||
for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES
|
|
||||||
}
|
|
||||||
|
|
||||||
# Compute standard cross-entropy loss for language modeling
|
# Compute standard cross-entropy loss for language modeling
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
@@ -982,22 +978,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
else torch.tensor(0.0, device=shift_logits.device)
|
else torch.tensor(0.0, device=shift_logits.device)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute per-dataset channel losses
|
|
||||||
_cross_entropy_loss = _cross_entropy_loss.view(batch_size, seq_length - 1)
|
|
||||||
non_ignored_mask = non_ignored_mask.view(batch_size, seq_length - 1)
|
|
||||||
for dataset_name_i in unique_datasets_name:
|
|
||||||
dataset_mask = torch.tensor(
|
|
||||||
[name == dataset_name_i for name in dataset_names],
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
combined_mask = dataset_mask.unsqueeze(1) & non_ignored_mask
|
|
||||||
channel_loss_dict[dataset_name_i] = (
|
|
||||||
_cross_entropy_loss[combined_mask].sum()
|
|
||||||
if combined_mask.any()
|
|
||||||
else torch.tensor(0.0, device=shift_logits.device)
|
|
||||||
)
|
|
||||||
channel_loss_count_dict[dataset_name_i] += combined_mask.sum()
|
|
||||||
|
|
||||||
# Add cross-entropy loss to total loss if valid
|
# Add cross-entropy loss to total loss if valid
|
||||||
if not torch.isnan(cross_entropy_loss):
|
if not torch.isnan(cross_entropy_loss):
|
||||||
loss += cross_entropy_loss
|
loss += cross_entropy_loss
|
||||||
@@ -1005,20 +985,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
cross_entropy_loss.detach()
|
cross_entropy_loss.detach()
|
||||||
|
|
||||||
# Compute action token prediction accuracy
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
action_preds = shift_logits.argmax(dim=-1)
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
if self.use_fast_tokenizer:
|
|
||||||
action_mask = (
|
|
||||||
shift_labels > self.action_token_id_set["fast_action_token_list"][0]
|
|
||||||
)
|
|
||||||
correct_preds = (action_preds == shift_labels) & action_mask
|
|
||||||
action_accuracy = (
|
|
||||||
correct_preds.sum().float() / action_mask.sum().float()
|
|
||||||
)
|
|
||||||
channel_loss_dict["action_accuracy"] = action_accuracy
|
|
||||||
|
|
||||||
if action_chunk is not None:
|
if action_chunk is not None:
|
||||||
action_mask = input_ids == self.action_token_id_set["action_token_id"]
|
action_mask = input_ids == self.action_token_id_set["action_token_id"]
|
||||||
if action_mask.any():
|
if action_mask.any():
|
||||||
@@ -1101,7 +1067,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||||
num_inference_timesteps: Optional[int] = 10,
|
num_inference_timesteps: Optional[int] = 10,
|
||||||
dataset_names: Optional[str] = None,
|
|
||||||
dof_mask: Optional[torch.FloatTensor] = None,
|
dof_mask: Optional[torch.FloatTensor] = None,
|
||||||
agent_pos_mask: Optional[torch.FloatTensor] = None,
|
agent_pos_mask: Optional[torch.FloatTensor] = None,
|
||||||
re_generate: bool = False,
|
re_generate: bool = False,
|
||||||
@@ -1140,7 +1105,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
cache_position (torch.LongTensor, optional): Cache position indices
|
cache_position (torch.LongTensor, optional): Cache position indices
|
||||||
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
|
second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid
|
||||||
num_inference_timesteps (int, optional): Number of diffusion inference steps
|
num_inference_timesteps (int, optional): Number of diffusion inference steps
|
||||||
dataset_names (str, optional): Dataset names for normalization
|
|
||||||
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask
|
dof_mask (torch.FloatTensor, optional): Degrees of freedom mask
|
||||||
agent_pos_mask (torch.FloatTensor, optional): Agent position mask
|
agent_pos_mask (torch.FloatTensor, optional): Agent position mask
|
||||||
re_generate (bool, optional): Whether to use sampling for regeneration
|
re_generate (bool, optional): Whether to use sampling for regeneration
|
||||||
@@ -1239,7 +1203,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
)
|
)
|
||||||
proprio_embed = self.action_preprocessor.proprioception_proj(
|
proprio_embed = self.action_preprocessor.proprioception_proj(
|
||||||
proprioception,
|
proprioception,
|
||||||
dataset_names,
|
|
||||||
agent_pos_mask,
|
agent_pos_mask,
|
||||||
use_history=proprioception.shape[1] > 1,
|
use_history=proprioception.shape[1] > 1,
|
||||||
)
|
)
|
||||||
@@ -1339,7 +1302,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
"dof_mask": dof_mask,
|
"dof_mask": dof_mask,
|
||||||
"agent_pos_mask": agent_pos_mask,
|
"agent_pos_mask": agent_pos_mask,
|
||||||
"proprioception": proprioception,
|
"proprioception": proprioception,
|
||||||
"dataset_names": dataset_names,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate output tokens
|
# Generate output tokens
|
||||||
@@ -1545,7 +1507,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
image_grid_thw=None,
|
image_grid_thw=None,
|
||||||
video_grid_thw=None,
|
video_grid_thw=None,
|
||||||
second_per_grid_ts=None,
|
second_per_grid_ts=None,
|
||||||
dataset_names=None,
|
|
||||||
proprioception=None,
|
proprioception=None,
|
||||||
dof_mask=None,
|
dof_mask=None,
|
||||||
agent_pos_mask=None,
|
agent_pos_mask=None,
|
||||||
@@ -1572,7 +1533,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
image_grid_thw: Image grid dimensions
|
image_grid_thw: Image grid dimensions
|
||||||
video_grid_thw: Video grid dimensions
|
video_grid_thw: Video grid dimensions
|
||||||
second_per_grid_ts: Time interval per temporal grid
|
second_per_grid_ts: Time interval per temporal grid
|
||||||
dataset_names: Dataset names for processing
|
|
||||||
proprioception: Proprioceptive sensor data
|
proprioception: Proprioceptive sensor data
|
||||||
dof_mask: Degrees of freedom mask
|
dof_mask: Degrees of freedom mask
|
||||||
agent_pos_mask: Agent position mask
|
agent_pos_mask: Agent position mask
|
||||||
@@ -1677,7 +1637,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
"cache_position": cache_position,
|
"cache_position": cache_position,
|
||||||
"second_per_grid_ts": second_per_grid_ts,
|
"second_per_grid_ts": second_per_grid_ts,
|
||||||
"proprioception": proprioception,
|
"proprioception": proprioception,
|
||||||
"dataset_names": dataset_names,
|
|
||||||
"dof_mask": dof_mask,
|
"dof_mask": dof_mask,
|
||||||
"agent_pos_mask": agent_pos_mask,
|
"agent_pos_mask": agent_pos_mask,
|
||||||
}
|
}
|
||||||
@@ -1866,150 +1825,14 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
config.validate_features()
|
config.validate_features()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Initialize VLM wrapper
|
# Initialize the wall-x model
|
||||||
self.model = Qwen2_5_VLMoEForAction(config)
|
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
|
||||||
|
self.model.to(config.device)
|
||||||
|
# Convert to bfloat16 for Flash Attention compatibility
|
||||||
|
self.model.to_bfloat16_for_selected_params()
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(
|
|
||||||
cls: builtins.type[T],
|
|
||||||
pretrained_name_or_path: str | Path,
|
|
||||||
*,
|
|
||||||
config: PreTrainedConfig | None = None,
|
|
||||||
force_download: bool = False,
|
|
||||||
resume_download: bool | None = None,
|
|
||||||
proxies: dict | None = None,
|
|
||||||
token: str | bool | None = None,
|
|
||||||
cache_dir: str | Path | None = None,
|
|
||||||
local_files_only: bool = False,
|
|
||||||
revision: str | None = None,
|
|
||||||
strict: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> T:
|
|
||||||
"""
|
|
||||||
Load WallXPolicy from a pretrained model path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained_name_or_path: Path to pretrained model or model identifier
|
|
||||||
config: Optional configuration object
|
|
||||||
force_download: Force download even if cached
|
|
||||||
resume_download: Resume interrupted download
|
|
||||||
proxies: Proxy configuration
|
|
||||||
token: Authentication token
|
|
||||||
cache_dir: Cache directory path
|
|
||||||
local_files_only: Only use local files
|
|
||||||
revision: Model revision
|
|
||||||
strict: Strict loading of state dict
|
|
||||||
**kwargs: Additional arguments
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
WallXPolicy: Loaded policy instance
|
|
||||||
"""
|
|
||||||
print(
|
|
||||||
"Loading Wall-X model for cross-embodiment robotic control.\n"
|
|
||||||
"This implementation integrates Qwen2.5-VL with flow matching for action prediction."
|
|
||||||
)
|
|
||||||
|
|
||||||
if pretrained_name_or_path is None:
|
|
||||||
raise ValueError("pretrained_name_or_path is required")
|
|
||||||
|
|
||||||
# Use provided config if available, otherwise load from pretrained path
|
|
||||||
if config is None:
|
|
||||||
config = PreTrainedConfig.from_pretrained(
|
|
||||||
pretrained_name_or_path=pretrained_name_or_path,
|
|
||||||
force_download=force_download,
|
|
||||||
resume_download=resume_download,
|
|
||||||
proxies=proxies,
|
|
||||||
token=token,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
revision=revision,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize model without loading weights
|
|
||||||
model = cls(config, **kwargs)
|
|
||||||
|
|
||||||
# Load and remap the state dict
|
|
||||||
try:
|
|
||||||
print(f"Loading model from: {pretrained_name_or_path}")
|
|
||||||
try:
|
|
||||||
from transformers.utils import cached_file
|
|
||||||
|
|
||||||
# Try safetensors first
|
|
||||||
resolved_file = cached_file(
|
|
||||||
pretrained_name_or_path,
|
|
||||||
"model.safetensors",
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
resume_download=resume_download,
|
|
||||||
proxies=proxies,
|
|
||||||
token=token,
|
|
||||||
revision=revision,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
original_state_dict = load_file(resolved_file)
|
|
||||||
print("✓ Loaded state dict from model.safetensors")
|
|
||||||
except Exception:
|
|
||||||
print(f"Could not load state dict: {e}")
|
|
||||||
print("Returning model without loading pretrained weights")
|
|
||||||
return model
|
|
||||||
|
|
||||||
# Filter out normalizer statistics if present
|
|
||||||
filtered_state_dict = {}
|
|
||||||
for key, value in original_state_dict.items():
|
|
||||||
if "action_preprocessor.normalizer" not in key:
|
|
||||||
filtered_state_dict[key] = value
|
|
||||||
else:
|
|
||||||
print(f"Filtered key: {key}")
|
|
||||||
|
|
||||||
# Add "model." prefix for keys that don't have it
|
|
||||||
remapped_state_dict = {}
|
|
||||||
remap_count = 0
|
|
||||||
|
|
||||||
for key, value in filtered_state_dict.items():
|
|
||||||
if not key.startswith("model."):
|
|
||||||
new_key = f"model.{key}"
|
|
||||||
remapped_state_dict[new_key] = value
|
|
||||||
remap_count += 1
|
|
||||||
else:
|
|
||||||
remapped_state_dict[key] = value
|
|
||||||
|
|
||||||
if remap_count > 0:
|
|
||||||
print(f"Remapped {remap_count} state dict keys")
|
|
||||||
|
|
||||||
# Load the remapped state dict into the model
|
|
||||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
|
||||||
|
|
||||||
if missing_keys:
|
|
||||||
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
|
|
||||||
if len(missing_keys) <= 5:
|
|
||||||
for key in missing_keys:
|
|
||||||
print(f" - {key}")
|
|
||||||
else:
|
|
||||||
for key in missing_keys[:5]:
|
|
||||||
print(f" - {key}")
|
|
||||||
print(f" ... and {len(missing_keys) - 5} more")
|
|
||||||
|
|
||||||
if unexpected_keys:
|
|
||||||
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
|
|
||||||
if len(unexpected_keys) <= 5:
|
|
||||||
for key in unexpected_keys:
|
|
||||||
print(f" - {key}")
|
|
||||||
else:
|
|
||||||
for key in unexpected_keys[:5]:
|
|
||||||
print(f" - {key}")
|
|
||||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
|
||||||
|
|
||||||
if not missing_keys and not unexpected_keys:
|
|
||||||
print("All keys loaded successfully!")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Could not load state dict: {e}")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset action queue."""
|
"""Reset action queue."""
|
||||||
self._queues = {
|
self._queues = {
|
||||||
@@ -2022,88 +1845,50 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def preprocess_inputs(
|
def preprocess_inputs(
|
||||||
self,
|
self,
|
||||||
batch: List[Dict[str, Any]],
|
batch: Dict[str, Any],
|
||||||
config: Dict[str, Any],
|
|
||||||
dataload_config: Dict[str, Any],
|
|
||||||
lerobot_config: Dict[str, Any],
|
|
||||||
processor: Any,
|
|
||||||
action_tokenizer: Optional[Any] = None,
|
|
||||||
camera_keys: Optional[List[str]] = None,
|
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
Convert a batch of LeRobot dataset items to Wall-X model input format.
|
Convert a batch of LeRobot dataset items to Wall-X model input format.
|
||||||
|
|
||||||
This is the batch version of convert_lerobot_to_wallx_format, processing multiple
|
This processes a batched dictionary where tensors have batch dimension first.
|
||||||
samples together for efficient batched inference or training.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: List of items from LeRobot dataset
|
batch: Dictionary with batched tensors:
|
||||||
config: Model configuration dict
|
- "observation.state": (batch_size, state_dim) or (batch_size, n_obs_steps, state_dim)
|
||||||
dataload_config: Data loading configuration
|
- "action": (batch_size, chunk_size, action_dim)
|
||||||
norm_stats: Normalization statistics
|
- "observation.images.<key>": (batch_size, C, H, W)
|
||||||
lerobot_config: LeRobot config containing 'repo_id'
|
- "task": List[str] of length batch_size
|
||||||
processor: Hugging Face processor for tokenization
|
|
||||||
action_tokenizer: Optional action tokenizer
|
|
||||||
camera_keys: List of camera keys in the dataset
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BatchFeature containing batched model inputs
|
BatchFeature containing batched model inputs
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> batch = [dataset[i] for i in range(8)]
|
|
||||||
>>> result = preprocess_inputs(
|
|
||||||
... batch=batch,
|
|
||||||
... config=config,
|
|
||||||
... dataload_config=dataload_config,
|
|
||||||
... norm_stats=norm_stats,
|
|
||||||
... lerobot_config={'repo_id': 'lerobot/aloha_mobile_cabinet'},
|
|
||||||
... processor=processor,
|
|
||||||
... )
|
|
||||||
"""
|
"""
|
||||||
repo_id = lerobot_config["repo_id"]
|
use_fast_tokenizer = self.config.use_fast_tokenizer
|
||||||
use_fast_tokenizer = config.get("use_fast_tokenizer", False)
|
|
||||||
|
|
||||||
# Get key mappings
|
# Get batch size from state tensor
|
||||||
cam_key_mapping = KEY_MAPPINGS[repo_id]["camera"]
|
batch_size = batch[OBS_STATE].shape[0]
|
||||||
state_key = KEY_MAPPINGS[repo_id]["state"]
|
|
||||||
action_key = KEY_MAPPINGS[repo_id]["action"]
|
|
||||||
|
|
||||||
# Build data config
|
|
||||||
data_config = X2RDataProcessingConfig().update(
|
|
||||||
train_test_split=dataload_config.get("train_test_split", 0.95),
|
|
||||||
split_seed=dataload_config.get("split_seed", 42),
|
|
||||||
predict_action_keys=dataload_config.get("predict_action_keys", []),
|
|
||||||
obs_action_keys=dataload_config.get("obs_action_keys", []),
|
|
||||||
resolution=dataload_config.get("resolution", None),
|
|
||||||
priority_order=dataload_config.get("priority_order", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
if camera_keys is None:
|
|
||||||
camera_keys = list(cam_key_mapping.keys())
|
|
||||||
|
|
||||||
# ==================== PROCESS ALL SAMPLES ====================
|
# ==================== PROCESS ALL SAMPLES ====================
|
||||||
all_image_inputs = []
|
all_image_inputs = []
|
||||||
all_texts = []
|
all_texts = []
|
||||||
all_agent_pos = []
|
|
||||||
all_actions = []
|
|
||||||
all_frame_indices = []
|
|
||||||
|
|
||||||
for data in batch:
|
# Find image keys in batch
|
||||||
|
img_keys = [key for key in self.config.image_features if key in batch]
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
# Vision preprocessing per sample
|
# Vision preprocessing per sample
|
||||||
processed_frames = []
|
processed_frames = []
|
||||||
orig_height, orig_width = None, None
|
orig_height, orig_width = None, None
|
||||||
resized_height, resized_width = None, None
|
resized_height, resized_width = None, None
|
||||||
|
|
||||||
for key in camera_keys:
|
for key in img_keys:
|
||||||
current_obs = data[key].clone()
|
current_obs = batch[key][i].clone() # (C, H, W)
|
||||||
if current_obs.dim() == 3:
|
if current_obs.dim() == 3:
|
||||||
current_obs = current_obs.permute(1, 2, 0)
|
current_obs = current_obs.permute(1, 2, 0) # (H, W, C)
|
||||||
|
|
||||||
img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy())
|
img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy())
|
||||||
orig_width, orig_height = img_pil.size
|
orig_width, orig_height = img_pil.size
|
||||||
|
|
||||||
cam_name = cam_key_mapping.get(key, key)
|
target_size = RESOLUTION
|
||||||
target_size = data_config.resolution.get(cam_name, -1)
|
|
||||||
if target_size != -1:
|
if target_size != -1:
|
||||||
if orig_width > orig_height:
|
if orig_width > orig_height:
|
||||||
new_width = target_size
|
new_width = target_size
|
||||||
@@ -2117,9 +1902,9 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
resized_height, resized_width = smart_resize(
|
resized_height, resized_width = smart_resize(
|
||||||
current_height,
|
current_height,
|
||||||
current_width,
|
current_width,
|
||||||
factor=data_config.image_factor,
|
factor=IMAGE_FACTOR,
|
||||||
min_pixels=data_config.min_pixels,
|
min_pixels=MIN_PIXELS,
|
||||||
max_pixels=data_config.max_pixels,
|
max_pixels=MAX_PIXELS,
|
||||||
)
|
)
|
||||||
resized_img = img_pil.resize((resized_width, resized_height))
|
resized_img = img_pil.resize((resized_width, resized_height))
|
||||||
processed_frames.append(resized_img)
|
processed_frames.append(resized_img)
|
||||||
@@ -2127,33 +1912,30 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
all_image_inputs.append(processed_frames)
|
all_image_inputs.append(processed_frames)
|
||||||
|
|
||||||
# Text preprocessing
|
# Text preprocessing
|
||||||
frame_index = data["frame_index"]
|
task_text = batch["task"][i] if isinstance(batch["task"], list) else batch["task"]
|
||||||
instruction_info = {"instruction": data["task"]}
|
instruction_info = {"instruction": task_text}
|
||||||
|
|
||||||
|
frame_index = batch["frame_index"][i] if "frame_index" in batch else 0
|
||||||
complete_text, _ = get_wallx_normal_text(
|
complete_text, _ = get_wallx_normal_text(
|
||||||
instruction_info,
|
instruction_info,
|
||||||
dataload_config.get("action_horizon", 33) - 1,
|
self.config.chunk_size,
|
||||||
frame_index,
|
frame_index,
|
||||||
data_config.priority_order,
|
PRIORITY_ORDER,
|
||||||
cam_key_mapping,
|
img_keys,
|
||||||
generate_subtask_ratio=data_config.generate_subtask_ratio,
|
generate_subtask_ratio=GENERATE_SUBTASK_RATIO,
|
||||||
)
|
)
|
||||||
|
|
||||||
text = process_grounding_points(
|
text = process_grounding_points(
|
||||||
complete_text, orig_height, orig_width, resized_height, resized_width,
|
complete_text, orig_height, orig_width, resized_height, resized_width,
|
||||||
data_config.model_type
|
MODEL_TYPE
|
||||||
)
|
)
|
||||||
all_texts.append(text)
|
all_texts.append(text)
|
||||||
|
|
||||||
# Collect raw values
|
|
||||||
all_agent_pos.append(data[state_key])
|
|
||||||
all_actions.append(data[action_key])
|
|
||||||
all_frame_indices.append(frame_index)
|
|
||||||
|
|
||||||
# Stack agent_pos
|
# ==================== PROCESS AGENT POS ====================
|
||||||
agent_pos = torch.stack(all_agent_pos)
|
agent_pos = batch[OBS_STATE] # (batch_size, state_dim)
|
||||||
if agent_pos.dim() == 2:
|
if agent_pos.dim() == 2:
|
||||||
agent_pos = agent_pos.unsqueeze(1)
|
agent_pos = agent_pos.unsqueeze(1) # (batch_size, 1, state_dim)
|
||||||
agent_pos_mask = (~torch.isnan(agent_pos)).float()
|
agent_pos_mask = (~torch.isnan(agent_pos)).float()
|
||||||
agent_pos = agent_pos.nan_to_num(nan=0.0)
|
agent_pos = agent_pos.nan_to_num(nan=0.0)
|
||||||
|
|
||||||
@@ -2161,15 +1943,15 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
pad_size = 20 - agent_pos.shape[-1]
|
pad_size = 20 - agent_pos.shape[-1]
|
||||||
agent_pos = torch.cat([
|
agent_pos = torch.cat([
|
||||||
agent_pos,
|
agent_pos,
|
||||||
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size)
|
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device)
|
||||||
], dim=-1)
|
], dim=-1)
|
||||||
agent_pos_mask = torch.cat([
|
agent_pos_mask = torch.cat([
|
||||||
agent_pos_mask,
|
agent_pos_mask,
|
||||||
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size)
|
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size, device=agent_pos_mask.device)
|
||||||
], dim=-1)
|
], dim=-1)
|
||||||
|
|
||||||
# Stack actions
|
# ==================== PROCESS ACTIONS ====================
|
||||||
action = torch.stack(all_actions)
|
action = batch[ACTION] # (batch_size, chunk_size, action_dim)
|
||||||
if action.dim() == 2:
|
if action.dim() == 2:
|
||||||
action = action.unsqueeze(1)
|
action = action.unsqueeze(1)
|
||||||
dof_mask = (~torch.isnan(action)).float()
|
dof_mask = (~torch.isnan(action)).float()
|
||||||
@@ -2179,36 +1961,35 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
pad_size = 20 - action.shape[-1]
|
pad_size = 20 - action.shape[-1]
|
||||||
action = torch.cat([
|
action = torch.cat([
|
||||||
action,
|
action,
|
||||||
torch.zeros(action.shape[0], action.shape[1], pad_size)
|
torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
|
||||||
], dim=-1)
|
], dim=-1)
|
||||||
dof_mask = torch.cat([
|
dof_mask = torch.cat([
|
||||||
dof_mask,
|
dof_mask,
|
||||||
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size)
|
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
|
||||||
], dim=-1)
|
], dim=-1)
|
||||||
|
|
||||||
# ==================== ACTION TOKEN REPLACEMENT ====================
|
# ==================== ACTION TOKEN REPLACEMENT ====================
|
||||||
all_texts = replace_action_token(
|
all_texts = replace_action_token(
|
||||||
all_texts,
|
all_texts,
|
||||||
action,
|
action,
|
||||||
action_tokenizer if use_fast_tokenizer else None,
|
self.model.action_tokenizer if use_fast_tokenizer else None,
|
||||||
[repo_id] * len(batch),
|
|
||||||
dof_mask,
|
dof_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ==================== TOKENIZATION ====================
|
# ==================== TOKENIZATION ====================
|
||||||
inputs = preprocesser_call(
|
inputs = preprocesser_call(
|
||||||
processor=processor,
|
processor=self.model.processor,
|
||||||
text=all_texts,
|
text=all_texts,
|
||||||
images=all_image_inputs,
|
images=all_image_inputs,
|
||||||
videos=None,
|
videos=None,
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
max_length=dataload_config.get("max_length", 768),
|
max_length=TOKENIZER_MAX_LENGTH,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ==================== ADDITIONAL INPUTS ====================
|
# ==================== ADDITIONAL INPUTS ====================
|
||||||
action_token_id = processor.tokenizer.convert_tokens_to_ids("<|action|>")
|
action_token_id = self.model.processor.tokenizer.convert_tokens_to_ids("<|action|>")
|
||||||
moe_token_types = inputs.input_ids == action_token_id
|
moe_token_types = inputs.input_ids == action_token_id
|
||||||
|
|
||||||
inputs["proprioception"] = agent_pos
|
inputs["proprioception"] = agent_pos
|
||||||
@@ -2216,11 +1997,13 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
inputs["action_chunk"] = action
|
inputs["action_chunk"] = action
|
||||||
inputs["dof_mask"] = dof_mask
|
inputs["dof_mask"] = dof_mask
|
||||||
inputs["moe_token_types"] = moe_token_types
|
inputs["moe_token_types"] = moe_token_types
|
||||||
inputs["dataset_names"] = [repo_id] * len(batch)
|
inputs["frame_index"] = batch["frame_index"] if "frame_index" in batch else torch.zeros(batch_size, device=batch[OBS_STATE].device)
|
||||||
inputs["frame_index"] = torch.stack([
|
|
||||||
torch.tensor(fi) if not isinstance(fi, torch.Tensor) else fi
|
# Move all tensors to the correct device
|
||||||
for fi in all_frame_indices
|
device = self.config.device
|
||||||
])
|
for key, value in inputs.items():
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
inputs[key] = value.to(device)
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@@ -2232,37 +2015,19 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
batch: Dictionary containing preprocessed inputs from preprocess_inputs()
|
batch: Dictionary containing preprocessed inputs from preprocess_inputs()
|
||||||
Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw,
|
Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw,
|
||||||
proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types,
|
proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types,
|
||||||
dataset_names, etc.
|
etc.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (loss, loss_dict)
|
tuple: (loss, loss_dict)
|
||||||
"""
|
"""
|
||||||
batch = self.preprocess_inputs(
|
batch = self.preprocess_inputs(
|
||||||
batch,
|
batch,
|
||||||
self.config,
|
|
||||||
self.dataload_config,
|
|
||||||
self.lerobot_config,
|
|
||||||
self.processor,
|
|
||||||
self.action_tokenizer,
|
|
||||||
self.camera_keys
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the underlying model's forward with mode="train"
|
# Call the underlying model's forward with mode="train"
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
mode="train",
|
**batch,
|
||||||
input_ids=batch.get("input_ids"),
|
mode="train"
|
||||||
attention_mask=batch.get("attention_mask"),
|
|
||||||
pixel_values=batch.get("pixel_values"),
|
|
||||||
image_grid_thw=batch.get("image_grid_thw"),
|
|
||||||
pixel_values_videos=batch.get("pixel_values_videos"),
|
|
||||||
video_grid_thw=batch.get("video_grid_thw"),
|
|
||||||
proprioception=batch.get("proprioception"),
|
|
||||||
agent_pos_mask=batch.get("agent_pos_mask"),
|
|
||||||
action_chunk=batch.get("action_chunk"),
|
|
||||||
dof_mask=batch.get("dof_mask"),
|
|
||||||
moe_token_types=batch.get("moe_token_types"),
|
|
||||||
dataset_names=batch.get("dataset_names"),
|
|
||||||
labels=batch.get("labels", batch.get("input_ids")), # Use input_ids as labels if not provided
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract losses from output
|
# Extract losses from output
|
||||||
@@ -2292,16 +2057,10 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
batch = self.preprocess_inputs(
|
batch = self.preprocess_inputs(
|
||||||
batch,
|
batch,
|
||||||
self.config,
|
|
||||||
self.dataload_config,
|
|
||||||
self.lerobot_config,
|
|
||||||
self.processor,
|
|
||||||
self.action_tokenizer,
|
|
||||||
self.camera_keys
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.prediction_mode == "diffusion":
|
if self.config.prediction_mode == "diffusion":
|
||||||
actions = self.model(
|
output = self.model(
|
||||||
**batch,
|
**batch,
|
||||||
action_dim=self.config.max_action_dim,
|
action_dim=self.config.max_action_dim,
|
||||||
pred_horizon=self.config.chunk_size,
|
pred_horizon=self.config.chunk_size,
|
||||||
@@ -2309,9 +2068,9 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
predict_mode="diffusion"
|
predict_mode="diffusion"
|
||||||
)
|
)
|
||||||
elif self.config.prediction_mode == "fast":
|
elif self.config.prediction_mode == "fast":
|
||||||
actions = self.model(
|
output = self.model(
|
||||||
**batch,
|
**batch,
|
||||||
action_dim=self.config.action_feature.shape[0],
|
action_dim=self.config.output_features["action"].shape[0],
|
||||||
pred_horizon=self.config.chunk_size,
|
pred_horizon=self.config.chunk_size,
|
||||||
mode="predict",
|
mode="predict",
|
||||||
predict_mode="fast"
|
predict_mode="fast"
|
||||||
@@ -2319,8 +2078,12 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
|
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
|
||||||
|
|
||||||
# Unpad actions
|
# Extract action tensor from output dictionary
|
||||||
actions = actions[:, :, :self.config.action_feature.shape[0]]
|
actions = output["predict_action"]
|
||||||
|
|
||||||
|
# Unpad actions to actual action dimension
|
||||||
|
action_dim = self.config.output_features["action"].shape[0]
|
||||||
|
actions = actions[:, :, :action_dim]
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
|||||||
@@ -29,13 +29,10 @@ from lerobot.processor import (
|
|||||||
PolicyProcessorPipeline,
|
PolicyProcessorPipeline,
|
||||||
ProcessorStepRegistry,
|
ProcessorStepRegistry,
|
||||||
RenameObservationsProcessorStep,
|
RenameObservationsProcessorStep,
|
||||||
TokenizerProcessorStep,
|
|
||||||
UnnormalizerProcessorStep,
|
UnnormalizerProcessorStep,
|
||||||
)
|
)
|
||||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
|
|
||||||
|
|
||||||
def make_wall_x_pre_post_processors(
|
def make_wall_x_pre_post_processors(
|
||||||
config: WallXConfig,
|
config: WallXConfig,
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
@@ -49,7 +46,6 @@ def make_wall_x_pre_post_processors(
|
|||||||
The pre-processing pipeline prepares input data for the model by:
|
The pre-processing pipeline prepares input data for the model by:
|
||||||
1. Renaming features to match pretrained configurations
|
1. Renaming features to match pretrained configurations
|
||||||
2. Adding a batch dimension
|
2. Adding a batch dimension
|
||||||
3. Tokenizing language task descriptions
|
|
||||||
4. Normalizing input and output features based on dataset statistics
|
4. Normalizing input and output features based on dataset statistics
|
||||||
5. Moving all data to the specified device
|
5. Moving all data to the specified device
|
||||||
|
|
||||||
@@ -65,25 +61,10 @@ def make_wall_x_pre_post_processors(
|
|||||||
A tuple containing the configured pre-processor and post-processor pipelines
|
A tuple containing the configured pre-processor and post-processor pipelines
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Try to use Qwen processor if available
|
|
||||||
try:
|
|
||||||
from transformers import AutoProcessor
|
|
||||||
tokenizer_name = config.vlm_model_name
|
|
||||||
qwen_available = True
|
|
||||||
except ImportError:
|
|
||||||
tokenizer_name = "Qwen/Qwen2-VL-2B-Instruct" # Fallback
|
|
||||||
qwen_available = False
|
|
||||||
|
|
||||||
input_steps = [
|
input_steps = [
|
||||||
RenameObservationsProcessorStep(rename_map={}),
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
AddBatchDimensionProcessorStep(),
|
AddBatchDimensionProcessorStep(),
|
||||||
WallXTaskProcessor(), # Process task description
|
WallXTaskProcessor(), # Process task description
|
||||||
TokenizerProcessorStep(
|
|
||||||
tokenizer_name=tokenizer_name,
|
|
||||||
padding="max_length",
|
|
||||||
padding_side="right",
|
|
||||||
max_length=config.tokenizer_max_length,
|
|
||||||
),
|
|
||||||
NormalizerProcessorStep(
|
NormalizerProcessorStep(
|
||||||
features={**config.input_features, **config.output_features},
|
features={**config.input_features, **config.output_features},
|
||||||
norm_map=config.normalization_mapping,
|
norm_map=config.normalization_mapping,
|
||||||
@@ -152,30 +133,3 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
|
|||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register(name="wall_x_image_processor")
|
|
||||||
class WallXImageProcessor(ComplementaryDataProcessorStep):
|
|
||||||
"""
|
|
||||||
Image processor for Wall-X using Qwen-VL vision processing.
|
|
||||||
|
|
||||||
This handles image formatting according to Qwen-VL requirements.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
try:
|
|
||||||
from transformers import AutoProcessor
|
|
||||||
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
|
||||||
self.available = True
|
|
||||||
except ImportError:
|
|
||||||
self.available = False
|
|
||||||
|
|
||||||
def complementary_data(self, complementary_data):
|
|
||||||
# Image processing is handled by the VLM processor
|
|
||||||
return complementary_data
|
|
||||||
|
|
||||||
def transform_features(
|
|
||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
|
||||||
return features
|
|
||||||
|
|||||||
@@ -33,10 +33,8 @@ from transformers import BatchFeature
|
|||||||
|
|
||||||
from lerobot.policies.wall_x.constant import (
|
from lerobot.policies.wall_x.constant import (
|
||||||
CAMERA_NAME_MAPPING,
|
CAMERA_NAME_MAPPING,
|
||||||
FREQUENCY_MAPPING,
|
|
||||||
KEY_MAPPINGS,
|
|
||||||
MULTIMODAL_DATASET_NAMES,
|
|
||||||
)
|
)
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -457,7 +455,7 @@ def get_wallx_normal_text(
|
|||||||
action_chunk_size: int,
|
action_chunk_size: int,
|
||||||
frame_idx: int,
|
frame_idx: int,
|
||||||
priority_order: Optional[OrderedDict] = None,
|
priority_order: Optional[OrderedDict] = None,
|
||||||
cam_mapping: Optional[Dict[str, str]] = None,
|
img_keys: Optional[List[str]] = None,
|
||||||
generate_subtask_ratio: float = 0.0,
|
generate_subtask_ratio: float = 0.0,
|
||||||
) -> Tuple[str, bool]:
|
) -> Tuple[str, bool]:
|
||||||
"""Construct complete multimodal prompt text for Wall-X model.
|
"""Construct complete multimodal prompt text for Wall-X model.
|
||||||
@@ -474,7 +472,7 @@ def get_wallx_normal_text(
|
|||||||
action_chunk_size: Number of action tokens to generate
|
action_chunk_size: Number of action tokens to generate
|
||||||
frame_idx: Current frame index
|
frame_idx: Current frame index
|
||||||
priority_order: Priority order for instruction sampling
|
priority_order: Priority order for instruction sampling
|
||||||
cam_mapping: Camera name mapping dictionary
|
img_keys: List of image keys
|
||||||
generate_subtask_ratio: Probability of generating subtask instead of actions
|
generate_subtask_ratio: Probability of generating subtask instead of actions
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -497,10 +495,10 @@ def get_wallx_normal_text(
|
|||||||
|
|
||||||
# User request with observation
|
# User request with observation
|
||||||
user_request = f"{role_start_symbol}user\nObservation:"
|
user_request = f"{role_start_symbol}user\nObservation:"
|
||||||
if cam_mapping:
|
if img_keys:
|
||||||
for _, cam_name in cam_mapping.items():
|
img_keys = img_key_mapping(img_keys)
|
||||||
view_name = CAMERA_NAME_MAPPING.get(cam_name, cam_name)
|
for key in img_keys:
|
||||||
user_request += f" {view_name}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}"
|
user_request += f" {key}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}"
|
||||||
user_request += "\nInstruction:"
|
user_request += "\nInstruction:"
|
||||||
|
|
||||||
# Get frame-specific instruction
|
# Get frame-specific instruction
|
||||||
@@ -543,6 +541,27 @@ def get_wallx_normal_text(
|
|||||||
complete_text = prologue + user_message + assistant_output
|
complete_text = prologue + user_message + assistant_output
|
||||||
return complete_text, generate_subtask
|
return complete_text, generate_subtask
|
||||||
|
|
||||||
|
def img_key_mapping(img_keys: List[str]) -> List[str]:
    """Convert dataset image keys into the view names used in the prompt.

    For each key, the leading ``OBS_IMAGES + "."`` namespace is dropped, then:

    * a key found in ``CAMERA_NAME_MAPPING`` is replaced by its mapped name;
    * otherwise, a key that already contains ``"view"`` has its underscores
      turned into spaces (``"face_view"`` -> ``"face view"``);
    * otherwise ``" view"`` is appended to the key as-is.

    Args:
        img_keys: Image feature keys, e.g. ``"observation.images.face_view"``
            (assumes ``OBS_IMAGES`` is ``"observation.images"`` -- confirm
            against ``lerobot.utils.constants``).

    Returns:
        A new list of view names, in the same order as ``img_keys``. The
        input list is not modified.
    """
    namespace_prefix = OBS_IMAGES + "."
    view_names = []
    for raw_key in img_keys:
        # Strip the namespace only where it really is a prefix; str.replace
        # would also delete an occurrence in the middle of a key.
        name = raw_key.removeprefix(namespace_prefix)
        if name in CAMERA_NAME_MAPPING:
            name = CAMERA_NAME_MAPPING[name]
        elif "view" in name:
            name = name.replace("_", " ")
        else:
            name = name + " view"
        view_names.append(name)
    return view_names
|
||||||
|
|
||||||
def get_action_tokens(
|
def get_action_tokens(
|
||||||
normalized_actions: Union[torch.Tensor, List], action_tokenizer
|
normalized_actions: Union[torch.Tensor, List], action_tokenizer
|
||||||
@@ -599,7 +618,6 @@ def replace_action_token(
|
|||||||
text: List[str],
|
text: List[str],
|
||||||
norm_action: Optional[torch.Tensor],
|
norm_action: Optional[torch.Tensor],
|
||||||
action_tokenizer,
|
action_tokenizer,
|
||||||
dataset_names: List[str],
|
|
||||||
dof_masks: Optional[torch.Tensor] = None,
|
dof_masks: Optional[torch.Tensor] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Replace action placeholders in text with actual action tokens.
|
"""Replace action placeholders in text with actual action tokens.
|
||||||
@@ -615,17 +633,10 @@ def replace_action_token(
|
|||||||
List of text strings with action tokens replaced
|
List of text strings with action tokens replaced
|
||||||
"""
|
"""
|
||||||
# Filter out multimodal dataset names
|
# Filter out multimodal dataset names
|
||||||
dataset_names = [
|
|
||||||
name for name in dataset_names if name not in MULTIMODAL_DATASET_NAMES
|
|
||||||
]
|
|
||||||
|
|
||||||
# Get required action chunk sizes
|
|
||||||
required_chunk_sizes = [32 for name in dataset_names]
|
|
||||||
|
|
||||||
if action_tokenizer is not None and norm_action is not None:
|
if action_tokenizer is not None and norm_action is not None:
|
||||||
# Extract actions based on chunk sizes and DOF masks
|
# Extract actions based on chunk sizes and DOF masks
|
||||||
norm_action = [
|
norm_action = [
|
||||||
action[: required_chunk_sizes[i], dof_masks[i, 0].bool()]
|
action[: 32, dof_masks[i, 0].bool()]
|
||||||
for i, action in enumerate(norm_action)
|
for i, action in enumerate(norm_action)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
# Wall-X policy tests
|
||||||
|
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Skip this entire module in CI
|
||||||
|
pytestmark = pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||||
|
reason="This test requires local Wall-X installation and is not meant for CI",
|
||||||
|
)
|
||||||
|
|
||||||
|
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||||
|
from lerobot.policies.wall_x import ( # noqa: E402
|
||||||
|
WallXConfig,
|
||||||
|
WallXPolicy,
|
||||||
|
make_wall_x_pre_post_processors, # noqa: E402
|
||||||
|
)
|
||||||
|
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||||
|
|
||||||
|
def test_policy_instantiation():
    """Smoke-test a Wall-X policy end to end on dummy data.

    Builds a ``WallXConfig`` with one state feature, one camera and a
    7-dimensional action, instantiates the policy together with its
    pre/post-processors, then runs one ``forward`` pass and one
    ``select_action`` call. The test passes when neither call raises; any
    exception propagates unchanged so the traceback points at the failing
    stage.
    """
    # The config below is pinned to CUDA. Skip with an explicit reason on
    # machines without a GPU instead of failing with a device error.
    if not torch.cuda.is_available():
        pytest.skip("Wall-X smoke test requires a CUDA device")

    from lerobot.configs.types import FeatureType, PolicyFeature

    set_seed(42)
    config = WallXConfig(device="cuda")

    # Set up input_features and output_features in the config
    config.input_features = {
        "observation.state": PolicyFeature(
            type=FeatureType.STATE,
            shape=(7,),
        ),
        "observation.images.face_view": PolicyFeature(
            type=FeatureType.VISUAL,
            shape=(3, 224, 224),
        ),
    }
    config.output_features = {
        "action": PolicyFeature(
            type=FeatureType.ACTION,
            shape=(7,),
        ),
    }

    # Identity statistics (zero mean, unit std) for every feature.
    dataset_stats = {
        "observation.state": {
            "mean": torch.zeros(7),
            "std": torch.ones(7),
        },
        "action": {
            "mean": torch.zeros(7),
            "std": torch.ones(7),
        },
        "observation.images.face_view": {
            "mean": torch.zeros(3, 224, 224),
            "std": torch.ones(3, 224, 224),
        },
    }

    # Instantiate policy and its processing pipelines
    policy = WallXPolicy(config)
    preprocessor, postprocessor = make_wall_x_pre_post_processors(
        config=config, dataset_stats=dataset_stats
    )

    # Dummy batch matching the features declared above
    batch_size = 1
    device = config.device
    batch = {
        "observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device),
        "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
        # rand (not randn) keeps pixel values in the [0, 1] range
        "observation.images.face_view": torch.rand(
            batch_size, 3, 224, 224, dtype=torch.float32, device=device
        ),
        "task": ["Pick up the object"] * batch_size,
    }
    batch = preprocessor(batch)

    # Training-style forward pass
    loss, loss_dict = policy.forward(batch)
    print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")

    # Inference: select one action and map it back through the postprocessor
    with torch.no_grad():
        action = policy.select_action(batch)
        action = postprocessor(action)
    print(f"Action: {action}")
    print(f"Action prediction successful. Action shape: {action.shape}")
|
||||||
|
|
||||||
|
def test_config_creation():
    """Test policy config creation through factory.

    Asks ``make_policy_config`` for the ``"wall_x"`` policy type and checks
    that the result is a ``WallXConfig``. Any exception raised by the factory
    propagates unchanged.
    """
    config = make_policy_config(
        policy_type="wall_x",
    )
    # Without this check the test would pass for any return value.
    assert isinstance(config, WallXConfig), (
        f"Expected WallXConfig from factory, got {type(config).__name__}"
    )
    print("Config created successfully through factory")
    print(f" Config type: {type(config).__name__}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running the smoke tests directly as a script, without pytest;
    # they run in declaration order and the first failure aborts the run.
    for smoke_test in (test_policy_instantiation, test_config_creation):
        smoke_test()
|
||||||
Reference in New Issue
Block a user