@@ -190,20 +190,17 @@ std::vector<py::object> fused_attn_fwd(
190190 cu_seqlens_kv_padded.value ().data_ptr (),
191191 static_cast <NVTEShape &>(cu_seqlens_kv_padded_shape), DType::kInt32 );
192192 }
193- NVTEShape default_scale_inv_shape;
194- default_scale_inv_shape.ndim = 1 ;
195- default_scale_inv_shape.data [0 ] = 1 ;
196193 if ((page_table_k.has_value ()) && (page_table_v.has_value ())) {
197194 auto page_table_k_sizes = page_table_k.value ().sizes ().vec ();
198195 NVTEShapeWrapper page_table_k_shape{page_table_k_sizes};
199196 auto page_table_v_sizes = page_table_v.value ().sizes ().vec ();
200197 NVTEShapeWrapper page_table_v_shape{page_table_v_sizes};
201198 te_page_table_k = makeTransformerEngineTensor (
202199 page_table_k.value ().data_ptr (), static_cast <NVTEShape &>(page_table_k_shape),
203- DType::kInt32 , nullptr , nullptr , nullptr , default_scale_inv_shape );
200+ DType::kInt32 , nullptr , nullptr , nullptr , TensorWrapper::defaultShape );
204201 te_page_table_v = makeTransformerEngineTensor (
205202 page_table_v.value ().data_ptr (), static_cast <NVTEShape &>(page_table_v_shape),
206- DType::kInt32 , nullptr , nullptr , nullptr , default_scale_inv_shape );
203+ DType::kInt32 , nullptr , nullptr , nullptr , TensorWrapper::defaultShape );
207204 }
208205
209206 // softmax offset
@@ -213,7 +210,7 @@ std::vector<py::object> fused_attn_fwd(
213210 NVTEShapeWrapper SoftmaxOffset_shape{SoftmaxOffset_sizes};
214211 te_SoftmaxOffset = makeTransformerEngineTensor (
215212 SoftmaxOffset.value ().data_ptr (), static_cast <NVTEShape &>(SoftmaxOffset_shape),
216- DType::kFloat32 , nullptr , nullptr , nullptr , default_scale_inv_shape );
213+ DType::kFloat32 , nullptr , nullptr , nullptr , TensorWrapper::defaultShape );
217214 }
218215
219216 // extract rng seed and offset
@@ -469,16 +466,13 @@ std::vector<py::object> fused_attn_bwd(
469466 NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes};
470467 auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes ().vec ();
471468 NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes};
472- NVTEShape zero_scale_inv_shape;
473- zero_scale_inv_shape.ndim = 1 ;
474- zero_scale_inv_shape.data [0 ] = 0 ;
475469 TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
476470 te_cu_seqlens_q = makeTransformerEngineTensor (
477471 cu_seqlens_q.data_ptr (), static_cast <NVTEShape &>(cu_seqlens_q_shape), DType::kInt32 , nullptr ,
478- nullptr , nullptr , zero_scale_inv_shape );
472+ nullptr , nullptr , TensorWrapper::emptyShape );
479473 te_cu_seqlens_kv = makeTransformerEngineTensor (
480474 cu_seqlens_kv.data_ptr (), static_cast <NVTEShape &>(cu_seqlens_kv_shape), DType::kInt32 ,
481- nullptr , nullptr , nullptr , zero_scale_inv_shape );
475+ nullptr , nullptr , nullptr , TensorWrapper::emptyShape );
482476
483477 TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
484478 if ((cu_seqlens_q_padded.has_value ()) && (cu_seqlens_kv_padded.has_value ())) {
0 commit comments