Skip to content

Commit 9283238

Browse files
committed
other minor cleanups
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
1 parent 1189e51 commit 9283238

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

transformer_engine/pytorch/csrc/extensions/attention.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -190,20 +190,17 @@ std::vector<py::object> fused_attn_fwd(
190190
cu_seqlens_kv_padded.value().data_ptr(),
191191
static_cast<NVTEShape &>(cu_seqlens_kv_padded_shape), DType::kInt32);
192192
}
193-
NVTEShape default_scale_inv_shape;
194-
default_scale_inv_shape.ndim = 1;
195-
default_scale_inv_shape.data[0] = 1;
196193
if ((page_table_k.has_value()) && (page_table_v.has_value())) {
197194
auto page_table_k_sizes = page_table_k.value().sizes().vec();
198195
NVTEShapeWrapper page_table_k_shape{page_table_k_sizes};
199196
auto page_table_v_sizes = page_table_v.value().sizes().vec();
200197
NVTEShapeWrapper page_table_v_shape{page_table_v_sizes};
201198
te_page_table_k = makeTransformerEngineTensor(
202199
page_table_k.value().data_ptr(), static_cast<NVTEShape &>(page_table_k_shape),
203-
DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape);
200+
DType::kInt32, nullptr, nullptr, nullptr, TensorWrapper::defaultShape);
204201
te_page_table_v = makeTransformerEngineTensor(
205202
page_table_v.value().data_ptr(), static_cast<NVTEShape &>(page_table_v_shape),
206-
DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape);
203+
DType::kInt32, nullptr, nullptr, nullptr, TensorWrapper::defaultShape);
207204
}
208205

209206
// softmax offset
@@ -213,7 +210,7 @@ std::vector<py::object> fused_attn_fwd(
213210
NVTEShapeWrapper SoftmaxOffset_shape{SoftmaxOffset_sizes};
214211
te_SoftmaxOffset = makeTransformerEngineTensor(
215212
SoftmaxOffset.value().data_ptr(), static_cast<NVTEShape &>(SoftmaxOffset_shape),
216-
DType::kFloat32, nullptr, nullptr, nullptr, default_scale_inv_shape);
213+
DType::kFloat32, nullptr, nullptr, nullptr, TensorWrapper::defaultShape);
217214
}
218215

219216
// extract rng seed and offset
@@ -469,16 +466,13 @@ std::vector<py::object> fused_attn_bwd(
469466
NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes};
470467
auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec();
471468
NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes};
472-
NVTEShape zero_scale_inv_shape;
473-
zero_scale_inv_shape.ndim = 1;
474-
zero_scale_inv_shape.data[0] = 0;
475469
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
476470
te_cu_seqlens_q = makeTransformerEngineTensor(
477471
cu_seqlens_q.data_ptr(), static_cast<NVTEShape &>(cu_seqlens_q_shape), DType::kInt32, nullptr,
478-
nullptr, nullptr, zero_scale_inv_shape);
472+
nullptr, nullptr, TensorWrapper::emptyShape);
479473
te_cu_seqlens_kv = makeTransformerEngineTensor(
480474
cu_seqlens_kv.data_ptr(), static_cast<NVTEShape &>(cu_seqlens_kv_shape), DType::kInt32,
481-
nullptr, nullptr, nullptr, zero_scale_inv_shape);
475+
nullptr, nullptr, nullptr, TensorWrapper::emptyShape);
482476

483477
TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
484478
if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {

0 commit comments

Comments
 (0)