diff --git a/src/common/device_vector.cuh b/src/common/device_vector.cuh index 8a561878de57..21af4dd95520 100644 --- a/src/common/device_vector.cuh +++ b/src/common/device_vector.cuh @@ -551,9 +551,13 @@ class DeviceUVectorImpl { void resize(std::size_t n) { // NOLINT using ::xgboost::common::SizeBytes; - if (n <= this->Capacity()) { + if (n == 0) { + return this->clear(); + } + // n is at the second half of the dynamic table, avoid re-allocation. + if (this->Capacity() / 2 <= n && n <= this->Capacity()) { this->size_ = n; - // early exit as no allocation is needed. + // Early exit return; } CHECK_LE(this->size(), this->Capacity()); @@ -568,10 +572,9 @@ class DeviceUVectorImpl { CHECK(new_ptr.get()); auto s = ::xgboost::curt::DefaultStream(); - safe_cuda(cudaMemcpyAsync(new_ptr.get(), this->data(), SizeBytes(this->size()), - cudaMemcpyDefault, s)); - this->size_ = n; - this->capacity_ = n; + std::size_t n_bytes = std::min(SizeBytes(this->size()), SizeBytes(n)); + safe_cuda(cudaMemcpyAsync(new_ptr.get(), this->data(), n_bytes, cudaMemcpyDefault, s)); + this->capacity_ = this->size_ = n; this->data_ = std::move(new_ptr); // swap failed with CTK12.8 @@ -588,7 +591,8 @@ class DeviceUVectorImpl { } void clear() { // NOLINT - this->resize(0); + this->data_.reset(); + this->capacity_ = this->size_ = 0; } [[nodiscard]] std::size_t size() const { return this->size_; } // NOLINT diff --git a/tests/cpp/common/test_device_vector.cu b/tests/cpp/common/test_device_vector.cu index 16a847648eb7..f5edcf1ce48d 100644 --- a/tests/cpp/common/test_device_vector.cu +++ b/tests/cpp/common/test_device_vector.cu @@ -43,25 +43,42 @@ TEST(DeviceUVector, Basic) { ASSERT_EQ(peak, n_bytes); std::swap(verbosity, xgboost::GlobalConfigThreadLocalStore::Get()->verbosity); - DeviceUVector uvec1{16}; - ASSERT_EQ(uvec1.size(), 16); - uvec1.resize(3); - ASSERT_EQ(uvec1.size(), 3); - ASSERT_EQ(uvec1.Capacity(), 16); - ASSERT_EQ(std::distance(uvec1.begin(), uvec1.end()), uvec1.size()); - auto orig = uvec1.size(); - - thrust::sequence(dh::CachingThrustPolicy(), uvec1.begin(), uvec1.end(), 0); - uvec1.resize(32); - ASSERT_EQ(uvec1.size(), 32); - ASSERT_EQ(uvec1.Capacity(), 32); - auto eq = thrust::equal(dh::CachingThrustPolicy(), uvec1.cbegin(), uvec1.cbegin() + orig, - thrust::make_counting_iterator(0)); - ASSERT_TRUE(eq); - - uvec1.clear(); - ASSERT_EQ(uvec1.size(), 0); - ASSERT_EQ(uvec1.Capacity(), 32); + { + // Second half of the dynamic table + DeviceUVector uvec{16}; + ASSERT_EQ(uvec.size(), 16); + uvec.resize(13); + ASSERT_EQ(uvec.size(), 13); + ASSERT_EQ(uvec.Capacity(), 16); + ASSERT_EQ(std::distance(uvec.begin(), uvec.end()), uvec.size()); + auto orig = uvec.size(); + + // Grow + thrust::sequence(dh::CachingThrustPolicy(), uvec.begin(), uvec.end(), 0); + uvec.resize(32); + ASSERT_EQ(uvec.size(), 32); + ASSERT_EQ(uvec.Capacity(), 32); + auto eq = thrust::equal(dh::CachingThrustPolicy(), uvec.cbegin(), uvec.cbegin() + orig, + thrust::make_counting_iterator(0)); + ASSERT_TRUE(eq); + + uvec.clear(); + ASSERT_EQ(uvec.size(), 0); + ASSERT_EQ(uvec.Capacity(), 0); + } + + { + // First half of the dynamic table + DeviceUVector uvec2{16}; + ASSERT_EQ(uvec2.Capacity(), 16); + thrust::sequence(dh::CachingThrustPolicy(), uvec2.begin(), uvec2.end(), 0); + std::size_t n = 4; + uvec2.resize(n); + ASSERT_EQ(uvec2.Capacity(), n); + auto eq = thrust::equal(dh::CachingThrustPolicy(), uvec2.cbegin(), + uvec2.cbegin() + uvec2.size(), thrust::make_counting_iterator(0)); + ASSERT_TRUE(eq); + } } #if defined(__linux__)