Skip to content

Commit 39770f4

Browse files
committed
Merge branch 'hierarchical-helpers' of https://github.com/tarang-jain/cuvs into hierarchical-helpers
2 parents 7c8fa6e + 1a6a145 commit 39770f4

9 files changed

Lines changed: 71 additions & 116 deletions

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ repos:
100100
files: |
101101
(?x)
102102
[.](cpp|cu|hpp|cuh)[.]in$
103+
# clang-format struggles with one of the placeholders in register_fatbin.cpp.in, so exclude it
104+
exclude: |
105+
(?x)
106+
^cpp/cmake/modules/register_fatbin[.]cpp[.]in$
103107
- repo: https://github.com/codespell-project/codespell
104108
rev: v2.4.1
105109
hooks:

cpp/CMakeLists.txt

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ if(NOT BUILD_CPU_ONLY)
401401

402402
block(PROPAGATE interleaved_scan_files metric_files filter_files post_lambda_files)
403403
set(CMAKE_CUDA_ARCHITECTURES ${JIT_LTO_TARGET_ARCHITECTURE})
404+
set(ivf_flat_ns "cuvs::neighbors::ivf_flat::detail")
404405
generate_jit_lto_kernels(
405406
interleaved_scan_files
406407
NAME_FORMAT
@@ -409,8 +410,10 @@ if(NOT BUILD_CPU_ONLY)
409410
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_matrix.json"
410411
KERNEL_INPUT_FILE
411412
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in"
412-
EMBEDDED_INPUT_FILE
413-
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_embedded.cpp.in"
413+
FRAGMENT_TAG_FORMAT
414+
"${ivf_flat_ns}::fragment_tag_interleaved_scan<${ivf_flat_ns}::tag_@type_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, ${ivf_flat_ns}::tag_idx_@idx_abbrev@, @capacity@, @veclen@, @ascending_value@, @compute_norm_value@>"
415+
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp>"
416+
"<cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp>"
414417
OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/interleaved_scan"
415418
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
416419
)
@@ -421,8 +424,10 @@ if(NOT BUILD_CPU_ONLY)
421424
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/metric_matrix.json"
422425
KERNEL_INPUT_FILE
423426
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/metric_kernel.cu.in"
424-
EMBEDDED_INPUT_FILE
425-
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/metric_embedded.cpp.in"
427+
FRAGMENT_TAG_FORMAT
428+
"${ivf_flat_ns}::fragment_tag_metric<@veclen@, ${ivf_flat_ns}::tag_@type_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, ${ivf_flat_ns}::tag_metric_@metric_name@>"
429+
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp>"
430+
"<cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp>"
426431
OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/metric"
427432
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
428433
)
@@ -433,8 +438,10 @@ if(NOT BUILD_CPU_ONLY)
433438
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/filter_matrix.json"
434439
KERNEL_INPUT_FILE
435440
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/filter_kernel.cu.in"
436-
EMBEDDED_INPUT_FILE
437-
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/filter_embedded.cpp.in"
441+
FRAGMENT_TAG_FORMAT
442+
"${ivf_flat_ns}::fragment_tag_filter<${ivf_flat_ns}::tag_filter<${ivf_flat_ns}::tag_idx_l, ${ivf_flat_ns}::tag_@filter_name@_impl>>"
443+
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp>"
444+
"<cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp>"
438445
OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/filter"
439446
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
440447
)
@@ -445,8 +452,10 @@ if(NOT BUILD_CPU_ONLY)
445452
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_matrix.json"
446453
KERNEL_INPUT_FILE
447454
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_kernel.cu.in"
448-
EMBEDDED_INPUT_FILE
449-
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_embedded.cpp.in"
455+
FRAGMENT_TAG_FORMAT
456+
"${ivf_flat_ns}::fragment_tag_post_lambda<${ivf_flat_ns}::tag_@post_lambda_name@>"
457+
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp>"
458+
"<cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp>"
450459
OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/post_lambda"
451460
KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements
452461
)

cpp/cmake/config.json

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
},
2727
"kwargs": {
2828
"KERNEL_FILE": 1,
29-
"EMBEDDED_HEADER_FILE": 1,
29+
"FATBIN_HEADER_FILE": 1,
3030
"LINK_LIBRARIES": "*"
3131
}
3232
},
@@ -37,7 +37,8 @@
3737
"kwargs": {
3838
"NAME_FORMAT": 1,
3939
"KERNEL_INPUT_FILE": 1,
40-
"EMBEDDED_INPUT_FILE": 1,
40+
"FRAGMENT_TAG_FORMAT": 1,
41+
"FRAGMENT_TAG_HEADER_FILES": "*",
4142
"OUTPUT_DIRECTORY": 1,
4243
"MATRIX_JSON_ENTRY": 1,
4344
"KERNEL_LINK_LIBRARIES": "*"
@@ -52,7 +53,8 @@
5253
"MATRIX_JSON_FILE": "?",
5354
"MATRIX_JSON_STRING": "?",
5455
"KERNEL_INPUT_FILE": 1,
55-
"EMBEDDED_INPUT_FILE": 1,
56+
"FRAGMENT_TAG_FORMAT": 1,
57+
"FRAGMENT_TAG_HEADER_FILES": "*",
5658
"OUTPUT_DIRECTORY": 1,
5759
"KERNEL_LINK_LIBRARIES": "*"
5860
}

cpp/cmake/modules/generate_jit_lto_kernels.cmake

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/compute_matrix_product.cmake)
1111

1212
function(add_jit_lto_kernel kernel_target)
1313
set(options)
14-
set(one_value KERNEL_FILE EMBEDDED_HEADER_FILE)
14+
set(one_value KERNEL_FILE FATBIN_HEADER_FILE)
1515
set(multi_value LINK_LIBRARIES)
1616

1717
cmake_parse_arguments(_JIT_LTO "${options}" "${one_value}" "${multi_value}" ${ARGN})
@@ -31,39 +31,48 @@ function(add_jit_lto_kernel kernel_target)
3131
)
3232

3333
add_custom_command(
34-
OUTPUT "${_JIT_LTO_EMBEDDED_HEADER_FILE}"
34+
OUTPUT "${_JIT_LTO_FATBIN_HEADER_FILE}"
3535
COMMAND "${bin_to_c}" --const --name embedded_fatbin --static $<TARGET_OBJECTS:${kernel_target}>
36-
> "${_JIT_LTO_EMBEDDED_HEADER_FILE}"
36+
> "${_JIT_LTO_FATBIN_HEADER_FILE}"
3737
DEPENDS $<TARGET_OBJECTS:${kernel_target}>
3838
)
3939
endfunction()
4040

4141
function(process_jit_lto_matrix_entry source_list_var)
4242
set(options)
43-
set(one_value NAME_FORMAT KERNEL_INPUT_FILE EMBEDDED_INPUT_FILE OUTPUT_DIRECTORY
43+
set(one_value NAME_FORMAT KERNEL_INPUT_FILE OUTPUT_DIRECTORY FRAGMENT_TAG_FORMAT
4444
MATRIX_JSON_ENTRY
4545
)
46-
set(multi_value KERNEL_LINK_LIBRARIES)
46+
set(multi_value KERNEL_LINK_LIBRARIES FRAGMENT_TAG_HEADER_FILES)
4747

4848
cmake_parse_arguments(_JIT_LTO "${options}" "${one_value}" "${multi_value}" ${ARGN})
4949

5050
populate_matrix_variables("${_JIT_LTO_MATRIX_JSON_ENTRY}")
5151
string(CONFIGURE "${_JIT_LTO_NAME_FORMAT}" kernel_name @ONLY)
52+
string(CONFIGURE "${_JIT_LTO_FRAGMENT_TAG_FORMAT}" fragment_tag @ONLY)
53+
54+
set(fragment_tag_header_files "")
55+
foreach(header_file IN LISTS _JIT_LTO_FRAGMENT_TAG_HEADER_FILES)
56+
if(NOT header_file MATCHES "^(\".*\"|<.*>)$")
57+
set(header_file "\"${header_file}\"")
58+
endif()
59+
string(APPEND fragment_tag_header_files "#include ${header_file}\n")
60+
endforeach()
5261

5362
set(kernel_file "${_JIT_LTO_OUTPUT_DIRECTORY}/${kernel_name}_kernel.cu")
5463
set(kernel_target "${kernel_name}_kernel")
55-
set(embedded_header_file "${_JIT_LTO_OUTPUT_DIRECTORY}/${kernel_name}_embedded.h")
56-
set(embedded_file "${_JIT_LTO_OUTPUT_DIRECTORY}/${kernel_name}_embedded.cpp")
64+
set(fatbin_header_file "${_JIT_LTO_OUTPUT_DIRECTORY}/${kernel_name}_fatbin.h")
65+
set(fatbin_file "${_JIT_LTO_OUTPUT_DIRECTORY}/${kernel_name}_fatbin.cpp")
5766
configure_file("${_JIT_LTO_KERNEL_INPUT_FILE}" "${kernel_file}" @ONLY)
58-
configure_file("${_JIT_LTO_EMBEDDED_INPUT_FILE}" "${embedded_file}" @ONLY)
67+
configure_file("${CMAKE_CURRENT_FUNCTION_LIST_DIR}/register_fatbin.cpp.in" "${fatbin_file}" @ONLY)
5968

6069
add_jit_lto_kernel(
6170
${kernel_target}
6271
KERNEL_FILE "${kernel_file}"
63-
EMBEDDED_HEADER_FILE "${embedded_header_file}"
72+
FATBIN_HEADER_FILE "${fatbin_header_file}"
6473
LINK_LIBRARIES ${_JIT_LTO_KERNEL_LINK_LIBRARIES}
6574
)
66-
list(APPEND ${source_list_var} "${embedded_header_file}" "${embedded_file}")
75+
list(APPEND ${source_list_var} "${fatbin_header_file}" "${fatbin_file}")
6776
set(${source_list_var}
6877
"${${source_list_var}}"
6978
PARENT_SCOPE
@@ -73,9 +82,9 @@ endfunction()
7382
function(generate_jit_lto_kernels source_list_var)
7483
set(options)
7584
set(one_value NAME_FORMAT MATRIX_JSON_FILE MATRIX_JSON_STRING KERNEL_INPUT_FILE
76-
EMBEDDED_INPUT_FILE OUTPUT_DIRECTORY
85+
FRAGMENT_TAG_FORMAT OUTPUT_DIRECTORY
7786
)
78-
set(multi_value KERNEL_LINK_LIBRARIES)
87+
set(multi_value KERNEL_LINK_LIBRARIES FRAGMENT_TAG_HEADER_FILES)
7988

8089
cmake_parse_arguments(_JIT_LTO "${options}" "${one_value}" "${multi_value}" ${ARGN})
8190

@@ -107,7 +116,8 @@ function(generate_jit_lto_kernels source_list_var)
107116
"${source_list_var}"
108117
NAME_FORMAT "${_JIT_LTO_NAME_FORMAT}"
109118
KERNEL_INPUT_FILE "${_JIT_LTO_KERNEL_INPUT_FILE}"
110-
EMBEDDED_INPUT_FILE "${_JIT_LTO_EMBEDDED_INPUT_FILE}"
119+
FRAGMENT_TAG_FORMAT "${_JIT_LTO_FRAGMENT_TAG_FORMAT}"
120+
FRAGMENT_TAG_HEADER_FILES ${_JIT_LTO_FRAGMENT_TAG_HEADER_FILES}
111121
OUTPUT_DIRECTORY "${_JIT_LTO_OUTPUT_DIRECTORY}"
112122
MATRIX_JSON_ENTRY "${matrix_json_entry}"
113123
KERNEL_LINK_LIBRARIES ${_JIT_LTO_KERNEL_LINK_LIBRARIES}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#include "@fatbin_header_file@"
7+
#include <cuvs/detail/jit_lto/FragmentEntry.hpp>
8+
9+
@fragment_tag_header_files@
10+
11+
namespace {
12+
13+
using fragment_tag = @fragment_tag@;
14+
using fragment_entry = StaticFatbinFragmentEntry<fragment_tag>;
15+
16+
} // namespace
17+
18+
template <>
19+
const uint8_t* const fragment_entry::data = embedded_fatbin;
20+
21+
template <>
22+
const size_t fragment_entry::length = sizeof(embedded_fatbin);

cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_embedded.cpp.in

Lines changed: 0 additions & 22 deletions
This file was deleted.

cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_embedded.cpp.in

Lines changed: 0 additions & 27 deletions
This file was deleted.

cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_embedded.cpp.in

Lines changed: 0 additions & 22 deletions
This file was deleted.

cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_embedded.cpp.in

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)