Skip to content
Open
202 changes: 202 additions & 0 deletions python-package/xgboost/testing/quantile_dmatrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""QuantileDMatrix related tests."""

from dataclasses import dataclass
from typing import Any, Callable, Optional

import numpy as np
import pytest
from sklearn.model_selection import train_test_split
Expand All @@ -8,6 +11,203 @@

from .data import make_batches, make_categorical

# Largest tolerated rank error for a quantile cut, normalized by the average
# bin weight (total_weight / num_cuts). Used when all sample weights are 1.0.
MAX_NORMALIZED_RANK_ERROR = 2.0
# Looser bound applied when non-uniform sample weights are provided
# (selected automatically in `_prepare_validation_input`).
MAX_WEIGHTED_NORMALIZED_RANK_ERROR = 14.0
Comment on lines +14 to +15
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please provide some brief comments on utilities here?



@dataclass(frozen=True)
class _RankContext:
    """Pre-computed per-column state used to evaluate quantile-cut rank errors."""

    # Feature values sorted in ascending order (missing values removed).
    sorted_x: np.ndarray
    # prefix_sum[i] is the cumulative weight of the first i sorted samples;
    # has length len(sorted_x) + 1 with a leading 0.0.
    prefix_sum: np.ndarray
    # Sum of all sample weights for this column.
    total_weight: float
    # Number of cut values produced for this column.
    num_cuts: int
    # total_weight / num_cuts; used to normalize absolute rank errors.
    avg_bin_weight: float


def _to_numpy(data: Any) -> np.ndarray:
if hasattr(data, "get"):
data = data.get()
elif hasattr(data, "to_pandas"):
data = data.to_pandas()
if hasattr(data, "to_numpy"):
data = data.to_numpy()
return np.asarray(data)


def _distance_to_interval(target: float, lo: float, hi: float) -> float:
if target < lo:
return lo - target
if target > hi:
return target - hi
return 0.0


def _prepare_validation_input(
    x: Any, w: Optional[Any]
) -> tuple[Any, np.ndarray, float]:
    """Normalize input data/weights and choose the default error tolerance.

    Returns the host-side feature matrix, a 1-D float64 weight vector (all
    ones when ``w`` is None), and the default normalized rank-error bound:
    the tight unweighted bound when every weight equals 1.0, otherwise the
    looser weighted bound.
    """
    host_x = x.get() if hasattr(x, "get") else x
    if hasattr(host_x, "to_pandas"):
        host_x = host_x.to_pandas()

    n_samples = host_x.shape[0]
    if w is None:
        weights = np.ones(n_samples, dtype=np.float64)
    else:
        weights = _to_numpy(w).astype(np.float64, copy=False)
        assert weights.ndim == 1
        assert weights.shape[0] == n_samples

    uniform = bool(np.all(weights == 1.0))
    tolerance = (
        MAX_NORMALIZED_RANK_ERROR if uniform else MAX_WEIGHTED_NORMALIZED_RANK_ERROR
    )
    return host_x, weights, tolerance


def _column_getter(
x_data: Any, weights: np.ndarray
) -> tuple[int, Callable[[int], tuple[np.ndarray, np.ndarray]]]:
if hasattr(x_data, "tocsc") and hasattr(x_data, "indptr"):
csc = x_data.tocsc()

def get_sparse_column(fidx: int) -> tuple[np.ndarray, np.ndarray]:
beg = int(csc.indptr[fidx])
end = int(csc.indptr[fidx + 1])
indices = csc.indices[beg:end]
column = np.asarray(csc.data[beg:end])
return column, weights[indices]

return csc.shape[1], get_sparse_column

x_dense = _to_numpy(x_data)
assert x_dense.ndim == 2

def get_dense_column(fidx: int) -> tuple[np.ndarray, np.ndarray]:
column = x_dense[:, fidx]
valid = ~np.isnan(column)
return column[valid], weights[valid]

return x_dense.shape[1], get_dense_column


def _sorted_rank_state(
column: np.ndarray, column_w: np.ndarray
) -> tuple[np.ndarray, np.ndarray, float]:
sorted_idx = np.argsort(column, kind="stable")
sorted_x = column[sorted_idx]
sorted_w = column_w[sorted_idx]
prefix_sum = np.concatenate(([0.0], np.cumsum(sorted_w, dtype=np.float64)))
return sorted_x, prefix_sum, float(prefix_sum[-1])


def _make_rank_context(
    column: np.ndarray, column_w: np.ndarray, column_cuts: np.ndarray
) -> _RankContext | None:
    """Build the rank-evaluation context for one column.

    Returns None when the column carries no total weight, in which case
    there is nothing to validate.
    """
    values, ranks, total = _sorted_rank_state(column, column_w)
    if total == 0.0:
        return None
    n_cuts = column_cuts.shape[0]
    return _RankContext(
        sorted_x=values,
        prefix_sum=ranks,
        total_weight=total,
        num_cuts=n_cuts,
        avg_bin_weight=total / float(n_cuts),
    )


def _rank_error_candidate(
    cut_idx: int,
    cut: float,
    rank_ctx: _RankContext,
) -> tuple[float, dict[str, float | int]]:
    """Evaluate one cut value against its ideal (uniformly spaced) rank.

    Samples equal to ``cut`` span a weight-rank interval; the absolute
    error is the distance from the ideal rank to that interval, and the
    returned error is normalized by the average bin weight. Also returns a
    dict of diagnostics for assertion messages.
    """
    lo_idx = np.searchsorted(rank_ctx.sorted_x, cut, side="left")
    hi_idx = np.searchsorted(rank_ctx.sorted_x, cut, side="right")
    rank_lo = float(rank_ctx.prefix_sum[lo_idx])
    rank_hi = float(rank_ctx.prefix_sum[hi_idx])
    ideal_rank = ((cut_idx + 1) * rank_ctx.total_weight) / float(rank_ctx.num_cuts)
    absolute_error = _distance_to_interval(ideal_rank, rank_lo, rank_hi)
    diagnostics: dict[str, float | int] = {
        "cut": cut_idx,
        "absolute_error": absolute_error,
        "target_rank": ideal_rank,
        "rank_lo": rank_lo,
        "rank_hi": rank_hi,
    }
    return absolute_error / rank_ctx.avg_bin_weight, diagnostics


def _max_rank_error_for_column(
    column: np.ndarray, column_w: np.ndarray, column_cuts: np.ndarray
) -> tuple[float, str]:
    """Return the worst normalized rank error over a column's cuts.

    The last cut is excluded from the scan (only ``column_cuts[:-1]`` is
    evaluated -- NOTE(review): presumably it is a boundary/sentinel value
    rather than a quantile estimate; confirm against the sketch output).
    Also returns a human-readable description of the worst cut for use in
    assertion messages.
    """
    rank_ctx = _make_rank_context(column, column_w, column_cuts)
    if rank_ctx is None:
        return 0.0, ""

    worst = 0.0
    worst_state: dict[str, float | int] = {
        "cut": 0,
        "absolute_error": 0.0,
        "target_rank": 0.0,
        "rank_lo": 0.0,
        "rank_hi": 0.0,
    }
    for cut_idx, cut in enumerate(column_cuts[:-1]):
        normalized, state = _rank_error_candidate(cut_idx, cut, rank_ctx)
        if normalized > worst:
            worst, worst_state = normalized, state

    details = (
        f"cut={worst_state['cut']}, normalized_error={worst}, "
        f"absolute_error={worst_state['absolute_error']}, "
        f"target_rank={worst_state['target_rank']}, rank_lo={worst_state['rank_lo']}, "
        f"rank_hi={worst_state['rank_hi']}, total_weight={rank_ctx.total_weight}, "
        f"num_cuts={column_cuts.shape[0]}"
    )
    return worst, details


def _assert_feature_rank_error(
    indptr: np.ndarray,
    cuts: np.ndarray,
    get_column: Callable[[int], tuple[np.ndarray, np.ndarray]],
    fidx: int,
    max_normalized_rank_error: float,
) -> None:
    """Validate the cuts of a single feature against the rank tolerance.

    Features with no valid values, or with at most one cut, are skipped.
    """
    column, column_w = get_column(fidx)
    if column.shape[0] == 0:
        return

    lo, hi = int(indptr[fidx]), int(indptr[fidx + 1])
    column_cuts = cuts[lo:hi]
    # Cut values must be non-decreasing within a feature.
    assert np.all(np.diff(column_cuts) >= 0.0)
    if column_cuts.shape[0] <= 1:
        return

    max_error, details = _max_rank_error_for_column(column, column_w, column_cuts)
    assert max_error <= max_normalized_rank_error, f"feature={fidx}, {details}"


def assert_cut_rank_error_within_tolerance(
    indptr: np.ndarray,
    cuts: np.ndarray,
    x: Any,
    w: Optional[Any] = None,
    max_normalized_rank_error: Optional[float] = None,
) -> None:
    """Assert that every numerical feature cut stays within the allowed rank error."""
    x_data, weights, default_rank_error = _prepare_validation_input(x, w)
    tolerance = (
        default_rank_error
        if max_normalized_rank_error is None
        else max_normalized_rank_error
    )

    n_features, get_column = _column_getter(x_data, weights)
    for fidx in range(n_features):
        _assert_feature_rank_error(indptr, cuts, get_column, fidx, tolerance)


def check_ref_quantile_cut(device: str) -> None:
"""Check obtaining the same cut values given a reference."""
Expand All @@ -30,10 +230,12 @@ def check_ref_quantile_cut(device: str) -> None:

np.testing.assert_allclose(cut_train[0], cut_valid[0])
np.testing.assert_allclose(cut_train[1], cut_valid[1])
assert_cut_rank_error_within_tolerance(cut_train[0], cut_train[1], X_train)

Xy_valid = xgb.QuantileDMatrix(X_valid, y_valid)
cut_valid = Xy_valid.get_quantile_cut()
assert not np.allclose(cut_train[1], cut_valid[1])
assert_cut_rank_error_within_tolerance(cut_valid[0], cut_valid[1], X_valid)


def check_categorical_strings(device: str) -> None:
Expand Down
3 changes: 3 additions & 0 deletions python-package/xgboost/testing/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..training import train
from .data import IteratorForTest, make_batches, make_categorical
from .data_iter import CatIter
from .quantile_dmatrix import assert_cut_rank_error_within_tolerance
from .utils import Device, assert_allclose, non_increasing


Expand Down Expand Up @@ -334,11 +335,13 @@ def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
Xyw: DMatrix = QuantileDMatrix(X, y, weight=w, max_bin=max_bin)
indptr, data = Xyw.get_quantile_cut()
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
assert_cut_rank_error_within_tolerance(indptr, data, X, w)
# - dm
Xyw = DMatrix(X, y, weight=w)
train({"tree_method": tree_method, "max_bin": max_bin}, Xyw)
indptr, data = Xyw.get_quantile_cut()
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
assert_cut_rank_error_within_tolerance(indptr, data, X, w)
# - ext mem
n_batches = 3
n_samples_per_batch = 256
Expand Down
4 changes: 1 addition & 3 deletions src/common/hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ constexpr float SketchContainer::kFactor;

namespace detail {
size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
double eps = 1.0 / (WQSketch::kFactor * max_bins);
size_t num_cuts = WQuantileSketch::LimitSizeLevel(num_rows, eps);
return std::min(num_cuts, num_rows);
return std::min(SketchSummaryBudget(max_bins, num_rows), num_rows);
}

size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, size_t max_bins,
Expand Down
Loading
Loading