mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
b74a551d38
* chore(gr00t): sync with #3606 for fixing gr00t config crash * fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions * fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete * feat(test): add compile test and benchmark for pi0 and pi05 * feat(test): add comprehensive testing for pi0 and pi05, including processor, forward, sample action, etc.
23 lines
726 B
Python
23 lines
726 B
Python
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
class Config:
    """Shape hyperparameters describing one Gemma variant.

    A plain, mutable dataclass: every field is a required integer, and the
    declaration order below is also the positional-argument order of the
    generated ``__init__``.
    """

    # NOTE(review): the per-field meanings are inferred from the names only;
    # confirm against the model code that consumes this config.
    width: int  # presumably the hidden/embedding size
    depth: int  # presumably the number of transformer layers
    mlp_dim: int  # presumably the feed-forward inner size
    num_heads: int  # presumably the number of attention (query) heads
    num_kv_heads: int  # presumably the number of key/value heads
    head_dim: int  # presumably the per-head dimension
|
|
|
|
|
|
def get_config(variant: str) -> Config:
    """Look up the Gemma shape config needed by the OpenPI PyTorch model.

    Args:
        variant: Name of the variant; ``"dummy"``, ``"gemma_300m"`` and
            ``"gemma_2b"`` are recognised.

    Returns:
        A newly constructed ``Config`` carrying the variant's dimensions.

    Raises:
        ValueError: If ``variant`` matches none of the recognised names.
    """
    # (name, field values) pairs, checked in order with ``==`` so that any
    # input type behaves exactly as it would in a chain of equality tests.
    known_shapes = (
        (
            "dummy",
            {"width": 64, "depth": 4, "mlp_dim": 128, "num_heads": 8, "num_kv_heads": 1, "head_dim": 16},
        ),
        (
            "gemma_300m",
            {"width": 1024, "depth": 18, "mlp_dim": 4096, "num_heads": 8, "num_kv_heads": 1, "head_dim": 256},
        ),
        (
            "gemma_2b",
            {"width": 2048, "depth": 18, "mlp_dim": 16_384, "num_heads": 8, "num_kv_heads": 1, "head_dim": 256},
        ),
    )
    for name, dims in known_shapes:
        if variant == name:
            # Build a fresh instance per call; nothing is cached or shared.
            return Config(**dims)
    raise ValueError(f"Unknown variant: {variant}")
|