From 51bd288f1a87a615c65d04b7413f269ebadcc6e8 Mon Sep 17 00:00:00 2001 From: Geoffrey19 Date: Tue, 16 Dec 2025 13:21:04 +0800 Subject: [PATCH] fix bug for inference --- src/lerobot/policies/wall_x/modeling_wall_x.py | 12 +++++++++++- tests/policies/wall_x/test_wallx.py | 9 +++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index 15a162c78..16175127d 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -1847,7 +1847,11 @@ class WallXPolicy(PreTrainedPolicy): self.config = config # Initialize the wall-x model - self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path, attn_implementation=config.attn_implementation) + self.model = Qwen2_5_VLMoEForAction.from_pretrained( + pretrained_name_or_path=config.pretrained_name_or_path, + action_tokenizer_path=config.action_tokenizer_path, + attn_implementation=config.attn_implementation + ) self.model.to(config.device) self.model.to_bfloat16_for_selected_params() @@ -1988,6 +1992,12 @@ class WallXPolicy(PreTrainedPolicy): dof_mask, torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device) ], dim=-1) + else: + action_dim = self.config.output_features["action"].shape[0] + dof_mask = torch.cat([ + torch.ones(batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device), + torch.zeros(batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device) + ], dim=-1) # ==================== ACTION TOKEN REPLACEMENT ==================== all_texts = replace_action_token( diff --git a/tests/policies/wall_x/test_wallx.py b/tests/policies/wall_x/test_wallx.py index 2440ec98b..8286655f0 100644 --- a/tests/policies/wall_x/test_wallx.py +++ b/tests/policies/wall_x/test_wallx.py @@ -99,6 +99,15 @@ def test_policy_instantiation(): print(f"Forward pass failed: {e}") raise + # Test inference + batch = { + "observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device), + "observation.images.face_view": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), # Use rand for [0,1] range + "task": ["Pick up the object"] * batch_size, + } + batch = preprocessor(batch) try: with torch.no_grad(): action = policy.select_action(batch)