fix bug for inference

This commit is contained in:
Geoffrey19
2025-12-16 13:21:04 +08:00
committed by Michel Aractingi
parent fc6262e23d
commit 51bd288f1a
2 changed files with 20 additions and 1 deletions
+11 -1
View File
@@ -1847,7 +1847,11 @@ class WallXPolicy(PreTrainedPolicy):
self.config = config
# Initialize the wall-x model
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path, attn_implementation=config.attn_implementation)
self.model = Qwen2_5_VLMoEForAction.from_pretrained(
pretrained_name_or_path=config.pretrained_name_or_path,
action_tokenizer_path=config.action_tokenizer_path,
attn_implementation=config.attn_implementation
)
self.model.to(config.device)
self.model.to_bfloat16_for_selected_params()
@@ -1988,6 +1992,12 @@ class WallXPolicy(PreTrainedPolicy):
dof_mask,
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
], dim=-1)
else:
action_dim = self.config.output_features["action"].shape[0]
dof_mask = torch.cat([
torch.ones(batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device),
torch.zeros(batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device)
], dim=-1)
# ==================== ACTION TOKEN REPLACEMENT ====================
all_texts = replace_action_token(
+9
View File
@@ -99,6 +99,15 @@ def test_policy_instantiation():
print(f"Forward pass failed: {e}")
raise
# Test inference
batch = {
"observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device),
"observation.images.face_view": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
), # Use rand for [0,1] range
"task": ["Pick up the object"] * batch_size,
}
batch = preprocessor(batch)
try:
with torch.no_grad():
action = policy.select_action(batch)