@@ -60,6 +60,53 @@ def get_shapes_for_config(
6060 "ffn.w2" : (M , 3584 , 8192 ),
6161 }
6262 shapes .extend ([(f"{ name } _{ k } " , v ) for k , v in llama_shapes .items ()])
63+ elif name == "llama4" :
64+ # LLaMa 4 shapes
65+ llama4_shapes = [
66+ ("FFN" , (16384 , 8192 , 5120 )),
67+ ("QO_proj" , (16384 , 8192 , 8192 )),
68+ ("KV_proj" , (16384 , 8192 , 1024 )),
69+ ("FFN" , (128000 , 8192 , 5120 )),
70+ ("QO_proj" , (128000 , 8192 , 8192 )),
71+ ("KV_proj" , (128000 , 8192 , 1024 )),
72+ ]
73+ shapes .extend ([(f"{ name } _{ k } " , v ) for k , v in llama4_shapes ])
74+ elif name == "deepseek_v3_236b" :
75+ # DeepSeek V3 236B shapes
76+ deepseek_v3_236b_shapes = [
77+ ("FFN" , (16384 , 1536 , 5120 )),
78+ ("QKVO_proj" , (16384 , 7168 , 7168 )),
79+ ("FFN" , (128000 , 1536 , 5120 )),
80+ ("QKVO_proj" , (128000 , 7168 , 7168 )),
81+ ]
82+ shapes .extend ([(f"{ name } _{ k } " , v ) for k , v in deepseek_v3_236b_shapes ])
83+ elif name == "deepseek_v3_671b" :
84+ # DeepSeek V3 671B shapes
85+ deepseek_v3_671b_shapes = [
86+ ("FFN" , (16384 , 2048 , 7168 )),
87+ ("QKVO_proj" , (16384 , 7168 , 7168 )),
88+ ("FFN" , (128000 , 2048 , 7168 )),
89+ ("QKVO_proj" , (128000 , 7168 , 7168 )),
90+ ]
91+ shapes .extend ([(f"{ name } _{ k } " , v ) for k , v in deepseek_v3_671b_shapes ])
92+ elif name == "qwen3_32b" :
93+ # Qwen3 32B shapes
94+ qwen3_32b_shapes = [
95+ ("QO_proj" , (16384 , 5120 , 5120 )),
96+ ("KV_proj" , (16384 , 5120 , 640 )),
97+ ("QO_proj" , (128000 , 5120 , 5120 )),
98+ ("KV_proj" , (128000 , 5120 , 640 )),
99+ ]
100+ shapes .extend ([(f"{ name } _{ k } " , v ) for k , v in qwen3_32b_shapes ])
101+ elif name == "gemma3_27b" :
102+ # Gemma3 27B shapes
103+ gemma3_27b_shapes = [
104+ ("QO_proj" , (16384 , 4096 , 4096 )),
105+ ("KV_proj" , (16384 , 4096 , 1024 )),
106+ ("QO_proj" , (128000 , 4096 , 4096 )),
107+ ("KV_proj" , (128000 , 4096 , 1024 )),
108+ ]
109+ shapes .extend ([(f"{ name } _{ k } " , v ) for k , v in gemma3_27b_shapes ])
63110 elif name == "pow2" :
64111 # Generate shapes with dimensions that are powers of 2
65112 min_power_of_2 = shape_config .get ("min_power" , 10 ) # 1024
@@ -105,7 +152,7 @@ def get_shapes_for_config(
105152 counter += 1
106153 else :
107154 raise NotImplementedError (
108- f"Shape config { name } not supported. Supported options: custom, llama, pow2, pow2_extended, sweep."
155+ f"Shape config { name } not supported. Supported options: custom, llama, llama4, deepseek_v3_236b, deepseek_v3_671b, qwen3_32b, gemma3_27b, pow2, pow2_extended, sweep."
109156 )
110157 return shapes
111158
0 commit comments