Skip to content

Commit 2ac18ed

Browse files
authored
[CUDA] Fallback Event impl when there is no hardware cpu/gpu coherency (#3070)
1 parent b537b36 commit 2ac18ed

File tree

4 files changed

+124
-57
lines changed

4 files changed

+124
-57
lines changed

mlx/backend/cuda/device.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ Device::Device(int device) : device_(device) {
4242
&concurrent_managed_access_,
4343
cudaDevAttrConcurrentManagedAccess,
4444
device_));
45+
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
46+
&host_native_atomic_, cudaDevAttrHostNativeAtomicSupported, device_));
47+
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
48+
&managed_memory_, cudaDevAttrManagedMemory, device_));
49+
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
50+
&memory_pools_, cudaDevAttrMemoryPoolsSupported, device_));
4551
}
4652

4753
Device::~Device() {

mlx/backend/cuda/device.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,24 @@ class Device {
164164
bool concurrent_managed_access() const {
165165
return concurrent_managed_access_ == 1;
166166
}
167+
bool host_native_atomic() const {
168+
return host_native_atomic_ == 1;
169+
}
170+
bool managed_memory() const {
171+
return managed_memory_ == 1;
172+
}
173+
bool memory_pools() const {
174+
return memory_pools_ == 1;
175+
}
167176

168177
private:
169178
int device_;
170179
int compute_capability_major_;
171180
int compute_capability_minor_;
172181
int concurrent_managed_access_;
182+
int host_native_atomic_;
183+
int managed_memory_;
184+
int memory_pools_;
173185
std::string device_name_;
174186
cublasLtHandle_t cublaslt_handle_{nullptr};
175187
cudnnHandle_t cudnn_handle_{nullptr};

mlx/backend/cuda/event.cu

Lines changed: 92 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -174,62 +174,94 @@ class CopyableCudaEvent {
174174
// AtomicEvent implementations
175175
///////////////////////////////////////////////////////////////////////////////
176176

177-
__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
178-
uint64_t current;
179-
while ((current = ac->load()) < value) {
180-
ac->wait(current);
177+
__host__ __device__ void event_wait(uint32_t* ptr, uint32_t value) {
178+
cuda::atomic_ref<uint32_t> ac(*ptr);
179+
uint32_t current;
180+
while ((current = ac.load()) < value) {
181+
ac.wait(current);
181182
}
182183
}
183184

184-
__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
185-
ac->store(value);
186-
ac->notify_all();
185+
__host__ __device__ void event_signal(uint32_t* ptr, uint32_t value) {
186+
cuda::atomic_ref<uint32_t> ac(*ptr);
187+
ac.store(value);
188+
ac.notify_all();
187189
}
188190

189-
__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
190-
event_wait(ac, value);
191+
__global__ void event_wait_kernel(uint32_t* ptr, uint32_t value) {
192+
event_wait(ptr, value);
191193
}
192194

193-
__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
194-
event_signal(ac, value);
195+
__global__ void event_signal_kernel(uint32_t* ptr, uint32_t value) {
196+
__threadfence_system();
197+
event_signal(ptr, value);
198+
__threadfence_system();
195199
}
196200

197-
bool supports_concurrent_managed_access() {
198-
static bool concurrent_managed_access = []() {
201+
auto check_gpu_coherency() {
202+
static auto coherency = []() {
199203
int device_count = gpu::device_count();
204+
bool concurrent_managed_access = true;
205+
bool host_native_atomic = true;
200206
for (int i = 0; i < device_count; ++i) {
201-
if (!cu::device(i).concurrent_managed_access()) {
202-
return false;
203-
}
207+
auto& d = cu::device(i);
208+
concurrent_managed_access &= d.concurrent_managed_access();
209+
host_native_atomic &= d.host_native_atomic();
204210
}
205-
return true;
211+
return std::make_tuple(concurrent_managed_access, host_native_atomic);
206212
}();
207-
return concurrent_managed_access;
213+
return coherency;
208214
}
209215

210216
AtomicEvent::AtomicEvent() {
211-
if (!supports_concurrent_managed_access()) {
212-
throw std::runtime_error(
213-
"Device does not support synchronization in managed memory.");
217+
void* buf;
218+
cudaError_t (*cuda_free)(void*);
219+
// There are 3 kinds of systems we are implementing for:
220+
// 1. concurrentManagedAccess == true
221+
// => use cuda::atomic_ref on managed memory
222+
// 2. hostNativeAtomicSupported == true
223+
// => use cuda::atomic_ref on pinned host memory
224+
// 3. no hardware cpu/gpu coherency
225+
// => use cuda::atomic_ref on device memory
226+
auto [concurrent_managed_access, host_native_atomic] = check_gpu_coherency();
227+
if (concurrent_managed_access) {
228+
CHECK_CUDA_ERROR(cudaMallocManaged(&buf, sizeof(uint32_t)));
229+
cuda_free = cudaFree;
230+
coherent_ = true;
231+
} else if (host_native_atomic) {
232+
CHECK_CUDA_ERROR(cudaMallocHost(&buf, sizeof(uint32_t)));
233+
cuda_free = cudaFreeHost;
234+
coherent_ = true;
235+
} else {
236+
CHECK_CUDA_ERROR(cudaMalloc(&buf, sizeof(uint32_t)));
237+
cuda_free = cudaFree;
238+
coherent_ = false;
239+
}
240+
buf_ = std::shared_ptr<void>(
241+
buf, [cuda_free](void* buf) { CHECK_CUDA_ERROR(cuda_free(buf)); });
242+
if (coherent_) {
243+
*ptr() = 0;
244+
} else {
245+
CHECK_CUDA_ERROR(cudaMemset(buf, 0, sizeof(uint32_t)));
214246
}
215-
buf_ = std::shared_ptr<Buffer>(
216-
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
217-
allocator().free(*ptr);
218-
delete ptr;
219-
});
220-
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
221247
}
222248

223-
void AtomicEvent::wait(uint64_t value) {
249+
void AtomicEvent::wait(uint32_t value) {
224250
nvtx3::scoped_range r("cu::AtomicEvent::wait");
225-
event_wait(atomic(), value);
251+
if (coherent_) {
252+
event_wait(ptr(), value);
253+
} else {
254+
while (!is_signaled(value)) {
255+
std::this_thread::yield();
256+
}
257+
}
226258
}
227259

228-
void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
229-
event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
260+
void AtomicEvent::wait(cudaStream_t stream, uint32_t value) {
261+
event_wait_kernel<<<1, 1, 0, stream>>>(ptr(), value);
230262
}
231263

232-
void AtomicEvent::wait(Stream s, uint64_t value) {
264+
void AtomicEvent::wait(Stream s, uint32_t value) {
233265
nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
234266
if (s.device == mlx::core::Device::cpu) {
235267
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
@@ -241,22 +273,26 @@ void AtomicEvent::wait(Stream s, uint64_t value) {
241273
}
242274
}
243275

244-
void AtomicEvent::signal(uint64_t value) {
276+
void AtomicEvent::signal(uint32_t value) {
245277
nvtx3::scoped_range r("cu::AtomicEvent::signal");
246-
event_signal(atomic(), value);
278+
if (coherent_) {
279+
event_signal(ptr(), value);
280+
} else {
281+
signal(signal_stream(), value);
282+
}
247283
}
248284

249-
void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
250-
event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
285+
void AtomicEvent::signal(cudaStream_t stream, uint32_t value) {
286+
event_signal_kernel<<<1, 1, 0, stream>>>(ptr(), value);
251287
}
252288

253-
void AtomicEvent::signal(Stream s, uint64_t value) {
289+
void AtomicEvent::signal(Stream s, uint32_t value) {
254290
nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
255291
if (s.device == mlx::core::Device::cpu) {
256292
// Signal through a GPU stream so the atomic is updated in GPU - updating
257293
// the atomic in CPU sometimes does not get GPU notified.
258-
static CudaStream stream(device(mlx::core::Device::gpu));
259-
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
294+
scheduler::enqueue(
295+
s, [*this, value]() mutable { signal(signal_stream(), value); });
260296
} else {
261297
auto& encoder = get_command_encoder(s);
262298
encoder.commit();
@@ -265,14 +301,26 @@ void AtomicEvent::signal(Stream s, uint64_t value) {
265301
}
266302
}
267303

268-
bool AtomicEvent::is_signaled(uint64_t value) const {
269-
nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
270-
return atomic()->load() >= value;
304+
bool AtomicEvent::is_signaled(uint32_t val) const {
305+
return value() >= val;
271306
}
272307

273-
uint64_t AtomicEvent::value() const {
308+
uint32_t AtomicEvent::value() const {
274309
nvtx3::scoped_range r("cu::AtomicEvent::value");
275-
return atomic()->load();
310+
if (coherent_) {
311+
cuda::atomic_ref<uint32_t> ac(*ptr());
312+
return ac.load();
313+
} else {
314+
uint32_t val;
315+
CHECK_CUDA_ERROR(
316+
cudaMemcpy(&val, ptr(), sizeof(uint32_t), cudaMemcpyDeviceToHost));
317+
return val;
318+
}
319+
}
320+
321+
const CudaStream& AtomicEvent::signal_stream() {
322+
static CudaStream stream(device(0));
323+
return stream;
276324
}
277325

278326
} // namespace cu

mlx/backend/cuda/event.h

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,25 +54,26 @@ class CudaEvent {
5454
// CudaEvent so the latter should always be preferred when possible.
5555
class AtomicEvent {
5656
public:
57-
using Atomic = cuda::atomic<uint64_t>;
58-
5957
AtomicEvent();
6058

61-
void wait(uint64_t value);
62-
void wait(cudaStream_t stream, uint64_t value);
63-
void wait(Stream s, uint64_t value);
64-
void signal(uint64_t value);
65-
void signal(cudaStream_t stream, uint64_t value);
66-
void signal(Stream s, uint64_t value);
67-
bool is_signaled(uint64_t value) const;
68-
uint64_t value() const;
59+
void wait(uint32_t value);
60+
void wait(cudaStream_t stream, uint32_t value);
61+
void wait(Stream s, uint32_t value);
62+
void signal(uint32_t value);
63+
void signal(cudaStream_t stream, uint32_t value);
64+
void signal(Stream s, uint32_t value);
65+
bool is_signaled(uint32_t value) const;
66+
uint32_t value() const;
6967

7068
private:
71-
Atomic* atomic() const {
72-
return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
69+
const CudaStream& signal_stream();
70+
71+
uint32_t* ptr() const {
72+
return static_cast<uint32_t*>(buf_.get());
7373
}
7474

75-
std::shared_ptr<allocator::Buffer> buf_;
75+
bool coherent_;
76+
std::shared_ptr<void> buf_;
7677
};
7778

7879
} // namespace mlx::core::cu

0 commit comments

Comments
 (0)