from dataclasses import dataclass @dataclass class Config: width: int depth: int mlp_dim: int num_heads: int num_kv_heads: int head_dim: int def get_config(variant: str) -> Config: """Return the Gemma shape config needed by the OpenPI PyTorch model.""" if variant == "dummy": return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16) if variant == "gemma_300m": return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256) if variant == "gemma_2b": return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256) raise ValueError(f"Unknown variant: {variant}")