@@ -174,62 +174,94 @@ class CopyableCudaEvent {
174174// AtomicEvent implementations
175175// /////////////////////////////////////////////////////////////////////////////
176176
// Block the caller until the counter at `ptr` reaches `value`.
// Uses cuda::atomic_ref::wait() so the thread sleeps on the address
// instead of spinning between observations.
__host__ __device__ void event_wait(uint32_t* ptr, uint32_t value) {
  cuda::atomic_ref<uint32_t> counter(*ptr);
  for (uint32_t observed = counter.load(); observed < value;
       observed = counter.load()) {
    counter.wait(observed);
  }
}
183184
// Store `value` into the counter at `ptr` and wake every waiter
// blocked in event_wait().
__host__ __device__ void event_signal(uint32_t* ptr, uint32_t value) {
  cuda::atomic_ref<uint32_t> counter(*ptr);
  counter.store(value);
  counter.notify_all();
}
188190
// Single-thread kernel: makes a GPU stream wait on the event.
// Launch as <<<1, 1>>>.
__global__ void event_wait_kernel(uint32_t* ptr, uint32_t value) {
  event_wait(ptr, value);
}
192194
// Single-thread kernel: signals the event from a GPU stream.
// Launch as <<<1, 1>>>. The system-scope fences order surrounding stream
// work around the signal for CPU observers. NOTE(review): they may be
// redundant with the system-scope atomic store inside event_signal —
// confirm before removing.
__global__ void event_signal_kernel(uint32_t* ptr, uint32_t value) {
  __threadfence_system();
  event_signal(ptr, value);
  __threadfence_system();
}
196200
// Probe (once per process) whether every visible device supports
// (a) concurrent managed access and (b) native host atomics.
// Returns a tuple (concurrent_managed_access, host_native_atomic); a
// capability is reported true only if ALL devices have it.
auto check_gpu_coherency() {
  static auto coherency = []() {
    bool managed = true;
    bool native_atomic = true;
    int device_count = gpu::device_count();
    for (int i = 0; i < device_count; ++i) {
      auto& d = cu::device(i);
      managed = managed && d.concurrent_managed_access();
      native_atomic = native_atomic && d.host_native_atomic();
    }
    return std::make_tuple(managed, native_atomic);
  }();
  return coherency;
}
209215
// Allocate and zero-initialize the event's 32-bit counter, choosing the
// memory space that matches the platform's CPU/GPU coherency support.
AtomicEvent::AtomicEvent() {
  void* buf;
  cudaError_t (*cuda_free)(void*);
  // There are 3 kinds of systems we are implementing for:
  // 1. concurrentManagedAccess == true
  //    => use cuda::atomic_ref on managed memory
  // 2. hostNativeAtomicSupported == true
  //    => use cuda::atomic_ref on pinned host memory
  // 3. no hardware cpu/gpu coherency
  //    => use device memory; the CPU side falls back to
  //       cudaMemcpy polling instead of host-side atomics
  auto [concurrent_managed_access, host_native_atomic] = check_gpu_coherency();
  if (concurrent_managed_access) {
    CHECK_CUDA_ERROR(cudaMallocManaged(&buf, sizeof(uint32_t)));
    cuda_free = cudaFree;
    coherent_ = true;
  } else if (host_native_atomic) {
    CHECK_CUDA_ERROR(cudaMallocHost(&buf, sizeof(uint32_t)));
    cuda_free = cudaFreeHost;
    coherent_ = true;
  } else {
    CHECK_CUDA_ERROR(cudaMalloc(&buf, sizeof(uint32_t)));
    cuda_free = cudaFree;
    coherent_ = false;
  }
  // Shared ownership so copies of the event keep the buffer alive; the
  // deleter uses whichever free function matches the allocation above.
  buf_ = std::shared_ptr<void>(
      buf, [cuda_free](void* buf) { CHECK_CUDA_ERROR(cuda_free(buf)); });
  // Zero the counter. Coherent memory is writable directly from the host;
  // device-only memory needs a cudaMemset.
  if (coherent_) {
    *ptr() = 0;
  } else {
    CHECK_CUDA_ERROR(cudaMemset(buf, 0, sizeof(uint32_t)));
  }
}
222248
// Block the calling CPU thread until the event value reaches `value`.
void AtomicEvent::wait(uint32_t value) {
  nvtx3::scoped_range r("cu::AtomicEvent::wait");
  if (!coherent_) {
    // No host-visible atomics: poll the device-side counter, yielding
    // between reads (each is_signaled() call reads the value back
    // from the device).
    while (!is_signaled(value)) {
      std::this_thread::yield();
    }
    return;
  }
  event_wait(ptr(), value);
}
227259
// Make GPU `stream` wait until the event value reaches `value`,
// via a single-thread helper kernel.
void AtomicEvent::wait(cudaStream_t stream, uint32_t value) {
  event_wait_kernel<<<1, 1, 0, stream>>>(ptr(), value);
}
231263
232- void AtomicEvent::wait (Stream s, uint64_t value) {
264+ void AtomicEvent::wait (Stream s, uint32_t value) {
233265 nvtx3::scoped_range r (" cu::AtomicEvent::wait(s)" );
234266 if (s.device == mlx::core::Device::cpu) {
235267 scheduler::enqueue (s, [*this , value]() mutable { wait (value); });
@@ -241,22 +273,26 @@ void AtomicEvent::wait(Stream s, uint64_t value) {
241273 }
242274}
243275
// Set the event value from the CPU.
void AtomicEvent::signal(uint32_t value) {
  nvtx3::scoped_range r("cu::AtomicEvent::signal");
  if (!coherent_) {
    // The CPU cannot write device-only memory directly; route the
    // signal through a dedicated GPU stream. NOTE(review): the kernel
    // launch is asynchronous — the store lands when the stream runs it.
    signal(signal_stream(), value);
    return;
  }
  event_signal(ptr(), value);
}
248284
// Set the event value from GPU `stream`, via a single-thread
// helper kernel.
void AtomicEvent::signal(cudaStream_t stream, uint32_t value) {
  event_signal_kernel<<<1, 1, 0, stream>>>(ptr(), value);
}
252288
253- void AtomicEvent::signal (Stream s, uint64_t value) {
289+ void AtomicEvent::signal (Stream s, uint32_t value) {
254290 nvtx3::scoped_range r (" cu::AtomicEvent::signal(s)" );
255291 if (s.device == mlx::core::Device::cpu) {
256292 // Signal through a GPU stream so the atomic is updated in GPU - updating
257293 // the atomic in CPU sometimes does not get GPU notified.
258- static CudaStream stream ( device (mlx::core::Device::gpu));
259- scheduler::enqueue ( s, [*this , value]() mutable { signal (stream , value); });
294+ scheduler::enqueue (
295+ s, [*this , value]() mutable { signal (signal_stream () , value); });
260296 } else {
261297 auto & encoder = get_command_encoder (s);
262298 encoder.commit ();
@@ -265,14 +301,26 @@ void AtomicEvent::signal(Stream s, uint64_t value) {
265301 }
266302}
267303
// True once the event value has reached `val`. Intentionally free of
// nvtx instrumentation: wait() calls this in a tight polling loop.
bool AtomicEvent::is_signaled(uint32_t val) const {
  uint32_t current = value();
  return current >= val;
}
272307
// Read the current event value from the CPU.
uint32_t AtomicEvent::value() const {
  nvtx3::scoped_range r("cu::AtomicEvent::value");
  if (!coherent_) {
    // Device-only memory: fetch the counter with a blocking memcpy.
    uint32_t val;
    CHECK_CUDA_ERROR(
        cudaMemcpy(&val, ptr(), sizeof(uint32_t), cudaMemcpyDeviceToHost));
    return val;
  }
  cuda::atomic_ref<uint32_t> counter(*ptr());
  return counter.load();
}
320+
// Lazily-created, process-lifetime stream used to signal from the CPU
// when the counter lives in device-only memory. NOTE(review): always
// created on device 0 — confirm this is intended on multi-GPU systems.
const CudaStream& AtomicEvent::signal_stream() {
  static CudaStream stream(device(0));
  return stream;
}
277325
278326} // namespace cu
0 commit comments