@@ -3062,10 +3062,11 @@ namespace experimental::udf {
30623062/**
30633063 * @brief Wrapper for vector elements that provides both packed and unpacked access.
30643064 *
3065- * For float/half : trivial wrapper around scalar values
3065+ * For float: trivial wrapper around scalar values
30663066 * For int8/uint8 with Veclen > 1: wraps packed bytes in a 32-bit word
30673067 *
3068- * @tparam T Data type (float, __half, int8_t, uint8_t)
3068+ * @tparam T Data type (float, int8_t, uint8_t). fp16 (`__half`) vector elements are not
3069+ * currently supported for UDFs; `metric_interface` rejects them with a static_assert when `cuda_fp16.h` is available.
30693070 * @tparam AccT Storage/accumulator type (float, __half, int32_t, uint32_t)
30703071 * @tparam Veclen Vector length (1, 2, 4, 8, 16)
30713072 */
@@ -3130,6 +3131,13 @@ template <typename T, typename AccT, int Veclen = 1>
struct metric_interface {
  // Abstract interface for an IVF-Flat custom metric (UDF). The struct holds no data members;
  // concrete metrics derive from it and override operator() below.

  // Element wrapper passed to operator(), instantiated with this struct's T / AccT / Veclen.
  using point_type = point<T, AccT, Veclen>;

#if CUVS_IVF_FLAT_UDF_HAS_CUDA_FP16
  // Compile-time rejection of fp16 element types. The check is only compiled when the macro
  // above evaluates to non-zero (presumably when cuda_fp16.h is available -- confirm at the
  // macro's definition). cv-qualifiers are stripped so `const __half` is rejected as well.
  static_assert(
    !(std::is_same_v<std::remove_cv_t<T>, __half> || std::is_same_v<std::remove_cv_t<T>, half>),
    "IVF-Flat custom metric UDF does not support fp16 (__half / half) at this time; do not set "
    "search_params.metric_udf for fp16 indices.");
#endif

  // Device-side metric callback. `acc` is taken by non-const reference so an implementation can
  // update it in place; `x` and `y` are passed by value.
  // NOTE(review): presumably `acc` accumulates the per-element distance contribution of (x, y)
  // -- confirm against the calling kernel.
  virtual __device__ void operator()(AccT& acc, point_type x, point_type y) = 0;

  // Defaulted virtual destructor, as required for a polymorphic base type.
  virtual ~metric_interface() = default;
};
@@ -3380,8 +3388,6 @@ __device__ __forceinline__ AccT max_elem(point<T, AccT, V> x, point<T, AccT, V>
33803388 * the necessary types and utilities inline.
33813389 */
33823390constexpr std::string_view jit_preamble_code = R" (
3383- #include <cuda_fp16.h>
3384-
33853391/* Fixed-width integer types for nvrtc */
33863392using int8_t = signed char ;
33873393using uint8_t = unsigned char ;
0 commit comments