Files
lerobot/tests/policies/pi0_pi05/openpi_pytorch/gemma.py
T
Haoming Song b74a551d38 fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610)
* chore(gr00t): sync with #3606 for fixing gr00t config crash

* fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions

* fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete

* feat(test): add compile test and benchamrk for pi0 and pi05

* feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
2026-05-22 10:29:34 +02:00

23 lines
726 B
Python

from dataclasses import dataclass
@dataclass
class Config:
width: int
depth: int
mlp_dim: int
num_heads: int
num_kv_heads: int
head_dim: int
def get_config(variant: str) -> Config:
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
if variant == "dummy":
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
if variant == "gemma_300m":
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
if variant == "gemma_2b":
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
raise ValueError(f"Unknown variant: {variant}")