mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 21:50:03 +00:00
fix bug for inference
This commit is contained in:
committed by
Michel Aractingi
parent
fc6262e23d
commit
51bd288f1a
@@ -1847,7 +1847,11 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
|
||||
# Initialize the wall-x model
|
||||
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path, attn_implementation=config.attn_implementation)
|
||||
self.model = Qwen2_5_VLMoEForAction.from_pretrained(
|
||||
pretrained_name_or_path=config.pretrained_name_or_path,
|
||||
action_tokenizer_path=config.action_tokenizer_path,
|
||||
attn_implementation=config.attn_implementation
|
||||
)
|
||||
self.model.to(config.device)
|
||||
self.model.to_bfloat16_for_selected_params()
|
||||
|
||||
@@ -1988,6 +1992,12 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
dof_mask,
|
||||
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
|
||||
], dim=-1)
|
||||
else:
|
||||
action_dim = self.config.output_features["action"].shape[0]
|
||||
dof_mask = torch.cat([
|
||||
torch.ones(batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device),
|
||||
torch.zeros(batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device)
|
||||
], dim=-1)
|
||||
|
||||
# ==================== ACTION TOKEN REPLACEMENT ====================
|
||||
all_texts = replace_action_token(
|
||||
|
||||
@@ -99,6 +99,15 @@ def test_policy_instantiation():
|
||||
print(f"Forward pass failed: {e}")
|
||||
raise
|
||||
|
||||
# Test inference
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.face_view": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
), # Use rand for [0,1] range
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
|
||||
Reference in New Issue
Block a user