def hidden_dimension_and_number_of_layers_from_parameter_count(
    parameter_count: float,
) -> tuple[int, int]:
    """Pick a (hidden dimension, layer count) pair for a model of the given size.

    The size is converted to billions of parameters and matched against a
    fixed lookup table. The table is ordered from the largest to the smallest
    size threshold, and the first row whose threshold does not exceed the
    requested size is used.

    Args:
        parameter_count: Total number of parameters (a raw count, not
            billions).

    Returns:
        A ``(hidden_dimension, number_of_layers)`` tuple.

    Raises:
        ValueError: If no threshold matches. Because the smallest threshold
            is 0, this happens for negative counts and for NaN.
    """
    # Rows are (minimum size in billions, hidden dimension, number of layers),
    # sorted by descending minimum size so the first match is the tightest.
    size_table = (
        (671, 20480, 160),
        (405, 16384, 120),
        (120, 12288, 96),
        (65, 8192, 80),
        (30, 7168, 60),
        (13, 5120, 40),
        (7, 4096, 32),
        (3, 3072, 26),
        (1, 2048, 22),
        (0, 1024, 18),
    )
    billions = parameter_count / 1e9
    chosen = next(
        (
            (width, depth)
            for floor, width, depth in size_table
            if billions >= floor
        ),
        None,
    )
    if chosen is None:
        raise ValueError('Invalid parameter count')
    return chosen