mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
fix bug for inference
This commit is contained in:
committed by
Michel Aractingi
parent
fc6262e23d
commit
51bd288f1a
@@ -1847,7 +1847,11 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Initialize the wall-x model
|
# Initialize the wall-x model
|
||||||
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path, attn_implementation=config.attn_implementation)
|
self.model = Qwen2_5_VLMoEForAction.from_pretrained(
|
||||||
|
pretrained_name_or_path=config.pretrained_name_or_path,
|
||||||
|
action_tokenizer_path=config.action_tokenizer_path,
|
||||||
|
attn_implementation=config.attn_implementation
|
||||||
|
)
|
||||||
self.model.to(config.device)
|
self.model.to(config.device)
|
||||||
self.model.to_bfloat16_for_selected_params()
|
self.model.to_bfloat16_for_selected_params()
|
||||||
|
|
||||||
@@ -1988,6 +1992,12 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
dof_mask,
|
dof_mask,
|
||||||
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
|
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
|
||||||
], dim=-1)
|
], dim=-1)
|
||||||
|
else:
|
||||||
|
action_dim = self.config.output_features["action"].shape[0]
|
||||||
|
dof_mask = torch.cat([
|
||||||
|
torch.ones(batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device),
|
||||||
|
torch.zeros(batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device)
|
||||||
|
], dim=-1)
|
||||||
|
|
||||||
# ==================== ACTION TOKEN REPLACEMENT ====================
|
# ==================== ACTION TOKEN REPLACEMENT ====================
|
||||||
all_texts = replace_action_token(
|
all_texts = replace_action_token(
|
||||||
|
|||||||
@@ -99,6 +99,15 @@ def test_policy_instantiation():
|
|||||||
print(f"Forward pass failed: {e}")
|
print(f"Forward pass failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# Test inference
|
||||||
|
batch = {
|
||||||
|
"observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device),
|
||||||
|
"observation.images.face_view": torch.rand(
|
||||||
|
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||||
|
), # Use rand for [0,1] range
|
||||||
|
"task": ["Pick up the object"] * batch_size,
|
||||||
|
}
|
||||||
|
batch = preprocessor(batch)
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
action = policy.select_action(batch)
|
action = policy.select_action(batch)
|
||||||
|
|||||||
Reference in New Issue
Block a user