mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610)
* chore(gr00t): sync with #3606 for fixing gr00t config crash * fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions * fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete * feat(test): add compile test and benchamrk for pi0 and pi05 * feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
|
||||
|
||||
def get_config(variant: str) -> Config:
|
||||
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
|
||||
if variant == "dummy":
|
||||
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
|
||||
if variant == "gemma_300m":
|
||||
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
if variant == "gemma_2b":
|
||||
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
Reference in New Issue
Block a user