30#ifndef SACADO_FAD_EXP_ATOMIC_HPP
31#define SACADO_FAD_EXP_ATOMIC_HPP
34#if defined(HAVE_SACADO_KOKKOSCORE)
37#include "Kokkos_Atomic.hpp"
38#include "impl/Kokkos_Error.hpp"
// Atomically adds the Fad expression xx into the ViewFad pointed to by dst:
// each derivative component and the value are added with Kokkos::atomic_add.
// NOTE(review): this listing is extraction-damaged (original line numbers are
// fused into the text and several lines — error checks, the derivative loop
// header — are missing); only the visible tokens are documented here.
46 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
48 void atomic_add(ViewFadPtr<ValT,sl,ss,U> dst,
const Expr<T>& xx) {
49 using Kokkos::atomic_add;
// Sizes of the source expression and destination Fad; resizing inside an
// atomic update is not supported, and mismatched nonzero sizes are an error.
53 const int xsz =
x.size();
54 const int sz = dst->size();
// Fixed spelling of the error message ("incompatiable" -> "incompatible").
60 "Sacado error: Fad resize within atomic_add() not supported!");
62 if (xsz != sz && sz > 0 && xsz > 0)
64 "Sacado error: Fad assignment of incompatible sizes!");
// Component-wise atomic add of the derivative array, then the value.
67 if (sz > 0 && xsz > 0) {
69 atomic_add(&(dst->fastAccessDx(
i)),
x.fastAccessDx(
i));
72 atomic_add(&(dst->val()),
x.val());
// Host implementation of op-fetch: locks the destination address, applies
// op.apply(*dest, val), and returns the NEW value (op-then-fetch semantics).
// KOKKOS_INTERNAL_NOT_PARALLEL selects a caller-scope (serial) memory scope;
// otherwise a device-wide scope is used. Acquire/release fences bracket the
// critical section. NOTE(review): extraction-damaged listing; some lines
// (return statement, unlock loop body) are missing from this view.
78 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
80 atomic_oper_fetch_host(
const Oper& op, DestPtrT dest, ValT* dest_val,
86#ifdef KOKKOS_INTERNAL_NOT_PARALLEL
87 auto scope = desul::MemoryScopeCaller();
89 auto scope = desul::MemoryScopeDevice();
// Spin until the desul address lock for dest_val is acquired.
92 while (!desul::Impl::lock_address((
void*)dest_val, scope))
94 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
95 return_type return_val = op.apply(*dest,
val);
97 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
98 desul::Impl::unlock_address((
void*)dest_val, scope);
// Host implementation of fetch-op: like atomic_oper_fetch_host but returns
// the OLD value (*dest is read into return_val before op.apply overwrites it).
// NOTE(review): extraction-damaged listing; return statement not visible.
102 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
104 atomic_fetch_oper_host(
const Oper& op, DestPtrT dest, ValT* dest_val,
110#ifdef KOKKOS_INTERNAL_NOT_PARALLEL
111 auto scope = desul::MemoryScopeCaller();
113 auto scope = desul::MemoryScopeDevice();
// Spin on the desul address lock protecting dest_val.
116 while (!desul::Impl::lock_address((
void*)dest_val, scope))
118 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
// Fetch old value first, then store op(old, val).
119 return_type return_val = *dest;
120 *dest = op.apply(return_val,
val);
121 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
122 desul::Impl::unlock_address((
void*)dest_val, scope);
// Device-side helper: true when Fad atomics must be performed cooperatively
// by the whole thread team (hierarchical-parallelism view layouts spread one
// Fad across blockDim.x threads). NOTE(review): the #else branch returning
// false is missing from this extraction-damaged view — confirm in upstream.
127#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
129 inline bool atomics_use_team() {
130#if defined(SACADO_VIEW_CUDA_HIERARCHICAL) || defined(SACADO_VIEW_CUDA_HIERARCHICAL_DFAD)
135 return (blockDim.x > 1);
// CUDA implementation of op-fetch (returns the NEW value).
// Two paths: (1) team path — thread 0 acquires the lock and broadcasts the
// result via Kokkos::shfl so the whole block proceeds together; (2) per-thread
// path — warp-collective retry loop using __activemask/__ballot_sync until
// every active lane has acquired the lock once and applied its update.
// NOTE(review): extraction-damaged listing; loop headers, 'done' flag setup
// and return statements are missing from this view.
142#if defined(KOKKOS_ENABLE_CUDA)
146 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
149 atomic_oper_fetch_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
155 auto scope = desul::MemoryScopeDevice();
157 if (atomics_use_team()) {
// Only thread 0 takes the lock; result broadcast to the block.
160 if (threadIdx.x == 0)
161 go = !desul::Impl::lock_address_cuda((
void*)dest_val, scope);
162 go = Kokkos::shfl(go, 0, blockDim.x);
164 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
165 return_type return_val = op.apply(*dest,
val);
167 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
168 if (threadIdx.x == 0)
169 desul::Impl::unlock_address_cuda((
void*)dest_val, scope);
173 return_type return_val;
// Warp-collective retry: iterate until the ballot of 'done' lanes
// matches the ballot of active lanes.
176 unsigned int mask = __activemask() ;
177 unsigned int active = __ballot_sync(mask, 1);
178 unsigned int done_active = 0;
179 while (active != done_active) {
181 if (desul::Impl::lock_address_cuda((
void*)dest_val, scope)) {
182 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
183 return_val = op.apply(*dest,
val);
185 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
186 desul::Impl::unlock_address_cuda((
void*)dest_val, scope);
190 done_active = __ballot_sync(mask, done);
// CUDA implementation of fetch-op (returns the OLD value): same team /
// warp-collective locking structure as atomic_oper_fetch_device above, but
// the previous *dest is captured before op.apply overwrites it.
// NOTE(review): extraction-damaged; in the per-thread path the read of *dest
// into return_val (original line 233) is not visible here — confirm upstream.
196 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
199 atomic_fetch_oper_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
205 auto scope = desul::MemoryScopeDevice();
207 if (atomics_use_team()) {
210 if (threadIdx.x == 0)
211 go = !desul::Impl::lock_address_cuda((
void*)dest_val, scope);
212 go = Kokkos::shfl(go, 0, blockDim.x);
214 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
// Fetch old value, then store op(old, val).
215 return_type return_val = *dest;
216 *dest = op.apply(return_val,
val);
217 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
218 if (threadIdx.x == 0)
219 desul::Impl::unlock_address_cuda((
void*)dest_val, scope);
223 return_type return_val;
226 unsigned int mask = __activemask() ;
227 unsigned int active = __ballot_sync(mask, 1);
228 unsigned int done_active = 0;
229 while (active != done_active) {
231 if (desul::Impl::lock_address_cuda((
void*)dest_val, scope)) {
232 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
234 *dest = op.apply(return_val,
val);
235 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
236 desul::Impl::unlock_address_cuda((
void*)dest_val, scope);
240 done_active = __ballot_sync(mask, done);
// HIP implementation of op-fetch (returns the NEW value). Mirrors the CUDA
// version but uses lock_address_hip/unlock_address_hip, __ballot (no sync
// mask), and Kokkos::Experimental::shfl for the team broadcast.
// NOTE(review): unlike the CUDA twin, no acquire/release fences are visible
// in the per-thread path here — possibly lost in extraction; confirm upstream.
246#elif defined(KOKKOS_ENABLE_HIP)
250 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
253 atomic_oper_fetch_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
259 auto scope = desul::MemoryScopeDevice();
261 if (atomics_use_team()) {
// Thread 0 acquires the lock; result broadcast to the whole block.
264 if (threadIdx.x == 0)
265 go = !desul::Impl::lock_address_hip((
void*)dest_val, scope);
266 go = Kokkos::Experimental::shfl(go, 0, blockDim.x);
268 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
269 return_type return_val = op.apply(*dest,
val);
271 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
272 if (threadIdx.x == 0)
273 desul::Impl::unlock_address_hip((
void*)dest_val, scope);
277 return_type return_val;
// Wavefront-collective retry loop until every active lane is done.
279 unsigned int active = __ballot(1);
280 unsigned int done_active = 0;
281 while (active != done_active) {
283 if (desul::Impl::lock_address_hip((
void*)dest_val, scope)) {
284 return_val = op.apply(*dest,
val);
286 desul::Impl::unlock_address_hip((
void*)dest_val, scope);
290 done_active = __ballot(done);
// HIP implementation of fetch-op (returns the OLD value). Mirrors the CUDA
// version with lock_address_hip/unlock_address_hip, __ballot, and
// Kokkos::Experimental::shfl. Fixed a silent bug on the release fence line:
// 'desul:atomic_thread_fence(...)' (single colon — parsed as a label named
// 'desul' followed by an unqualified call) is now 'desul::...'.
// NOTE(review): extraction-damaged; the read of *dest into return_val in the
// per-thread path (original line 330) is not visible here — confirm upstream.
296 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
299 atomic_fetch_oper_device(
const Oper& op, DestPtrT dest, ValT* dest_val,
305 auto scope = desul::MemoryScopeDevice();
307 if (atomics_use_team()) {
310 if (threadIdx.x == 0)
311 go = !desul::Impl::lock_address_hip((
void*)dest_val, scope);
312 go = Kokkos::Experimental::shfl(go, 0, blockDim.x);
314 desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
// Fetch old value, then store op(old, val).
315 return_type return_val = *dest;
316 *dest = op.apply(return_val,
val);
317 desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
318 if (threadIdx.x == 0)
319 desul::Impl::unlock_address_hip((
void*)dest_val, scope);
323 return_type return_val;
325 unsigned int active = __ballot(1);
326 unsigned int done_active = 0;
327 while (active != done_active) {
329 if (desul::Impl::lock_address_hip((
void*)dest_val, scope)) {
331 *dest = op.apply(return_val,
val);
332 desul::Impl::unlock_address_hip((
void*)dest_val, scope);
336 done_active = __ballot(done);
// Dispatcher: op-fetch on a GeneralFad, taking the scalar lock address from
// dest->val(). KOKKOS_IF_ON_HOST/DEVICE select the host or device backend.
346 template <
typename Oper,
typename S>
348 atomic_oper_fetch(
const Oper& op, GeneralFad<S>* dest,
349 const GeneralFad<S>&
val)
351 KOKKOS_IF_ON_HOST(
return Impl::atomic_oper_fetch_host(op, dest, &(dest->val()),
val);)
352 KOKKOS_IF_ON_DEVICE(
return Impl::atomic_oper_fetch_device(op, dest, &(dest->val()),
val);)
// Dispatcher: op-fetch on a ViewFadPtr destination with a general Fad
// expression value; lock address is &dest.val().
354 template <
typename Oper,
typename ValT,
unsigned sl,
unsigned ss,
355 typename U,
typename T>
357 atomic_oper_fetch(
const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
360 KOKKOS_IF_ON_HOST(
return Impl::atomic_oper_fetch_host(op, dest, &dest.val(),
val);)
361 KOKKOS_IF_ON_DEVICE(
return Impl::atomic_oper_fetch_device(op, dest, &dest.val(),
val);)
// Dispatcher: fetch-op (returns old value) on a GeneralFad destination.
364 template <
typename Oper,
typename S>
366 atomic_fetch_oper(
const Oper& op, GeneralFad<S>* dest,
367 const GeneralFad<S>&
val)
369 KOKKOS_IF_ON_HOST(
return Impl::atomic_fetch_oper_host(op, dest, &(dest->val()),
val);)
370 KOKKOS_IF_ON_DEVICE(
return Impl::atomic_fetch_oper_device(op, dest, &(dest->val()),
val);)
// Dispatcher: fetch-op (returns old value) on a ViewFadPtr destination.
372 template <
typename Oper,
typename ValT,
unsigned sl,
unsigned ss,
373 typename U,
typename T>
375 atomic_fetch_oper(
const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
378 KOKKOS_IF_ON_HOST(
return Impl::atomic_fetch_oper_host(op, dest, &dest.val(),
val);)
379 KOKKOS_IF_ON_DEVICE(
return Impl::atomic_fetch_oper_device(op, dest, &dest.val(),
val);)
// Max operator functor: apply(a,b) returns max(a,b) via unqualified call
// (ADL picks the Fad overload). Enclosing struct header lost in extraction.
384 template <
class Scalar1,
class Scalar2>
385 KOKKOS_FORCEINLINE_FUNCTION
386 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
387 ->
decltype(
max(val1,val2))
389 return max(val1,val2);
// Min operator functor: apply(a,b) returns min(a,b) via unqualified (ADL)
// call. Enclosing struct header lost in extraction.
393 template <
class Scalar1,
class Scalar2>
394 KOKKOS_FORCEINLINE_FUNCTION
395 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
396 ->
decltype(
min(val1,val2))
398 return min(val1,val2);
// Add operator functor: apply(a,b) -> a+b. Body (return statement) and
// enclosing struct header lost in extraction; only the signature remains.
402 template <
class Scalar1,
class Scalar2>
403 KOKKOS_FORCEINLINE_FUNCTION
404 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
405 ->
decltype(val1+val2)
// Subtract operator functor: apply(a,b) -> a-b. Body lost in extraction.
411 template <
class Scalar1,
class Scalar2>
412 KOKKOS_FORCEINLINE_FUNCTION
413 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
414 ->
decltype(val1-val2)
// Multiply operator functor: apply(a,b) -> a*b. Body lost in extraction.
420 template <
class Scalar1,
class Scalar2>
421 KOKKOS_FORCEINLINE_FUNCTION
422 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
423 ->
decltype(val1*val2)
// Divide operator functor: apply(a,b) -> a/b. Body lost in extraction.
429 template <
class Scalar1,
class Scalar2>
430 KOKKOS_FORCEINLINE_FUNCTION
431 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
432 ->
decltype(val1/val2)
// Public API: atomic max-then-fetch (returns the new value) for GeneralFad
// and ViewFadPtr destinations; both forward to Impl::atomic_oper_fetch.
442 template <
typename S>
444 atomic_max_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
445 return Impl::atomic_oper_fetch(Impl::MaxOper(), dest,
val);
447 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
449 atomic_max_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
450 return Impl::atomic_oper_fetch(Impl::MaxOper(), dest,
val);
// Public API: atomic min-then-fetch (returns the new value).
452 template <
typename S>
454 atomic_min_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
455 return Impl::atomic_oper_fetch(Impl::MinOper(), dest,
val);
457 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
459 atomic_min_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
460 return Impl::atomic_oper_fetch(Impl::MinOper(), dest,
val);
// Public API: atomic add-then-fetch (returns the new value).
462 template <
typename S>
464 atomic_add_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
465 return Impl::atomic_oper_fetch(Impl::AddOper(), dest,
val);
467 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
469 atomic_add_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
470 return Impl::atomic_oper_fetch(Impl::AddOper(), dest,
val);
// Public API: atomic subtract-then-fetch (returns the new value).
472 template <
typename S>
474 atomic_sub_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
475 return Impl::atomic_oper_fetch(Impl::SubOper(), dest,
val);
477 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
479 atomic_sub_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
480 return Impl::atomic_oper_fetch(Impl::SubOper(), dest,
val);
// Public API: atomic multiply-then-fetch (returns the new value) for a
// GeneralFad destination. Consistency fix: every sibling wrapper in this file
// qualifies the dispatcher as Impl::atomic_oper_fetch; this one called it
// unqualified, which fails (or resolves differently) outside namespace Impl.
482 template <
typename S>
484 atomic_mul_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
485 return Impl::atomic_oper_fetch(Impl::MulOper(), dest,
val);
// Public API: atomic multiply-then-fetch for a ViewFadPtr destination.
487 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
489 atomic_mul_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
490 return Impl::atomic_oper_fetch(Impl::MulOper(), dest,
val);
// Public API: atomic divide-then-fetch (returns the new value).
492 template <
typename S>
494 atomic_div_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
495 return Impl::atomic_oper_fetch(Impl::DivOper(), dest,
val);
497 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
499 atomic_div_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
500 return Impl::atomic_oper_fetch(Impl::DivOper(), dest,
val);
// Public API: atomic fetch-then-max (returns the OLD value) — forwards to
// Impl::atomic_fetch_oper, unlike the *_fetch variants above.
503 template <
typename S>
505 atomic_fetch_max(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
506 return Impl::atomic_fetch_oper(Impl::MaxOper(), dest,
val);
508 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
510 atomic_fetch_max(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
511 return Impl::atomic_fetch_oper(Impl::MaxOper(), dest,
val);
// Public API: atomic fetch-then-min (returns the OLD value).
513 template <
typename S>
515 atomic_fetch_min(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
516 return Impl::atomic_fetch_oper(Impl::MinOper(), dest,
val);
518 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
520 atomic_fetch_min(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
521 return Impl::atomic_fetch_oper(Impl::MinOper(), dest,
val);
// Public API: atomic fetch-then-add (returns the OLD value).
523 template <
typename S>
525 atomic_fetch_add(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
526 return Impl::atomic_fetch_oper(Impl::AddOper(), dest,
val);
528 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
530 atomic_fetch_add(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
531 return Impl::atomic_fetch_oper(Impl::AddOper(), dest,
val);
// Public API: atomic fetch-then-subtract (returns the OLD value).
533 template <
typename S>
535 atomic_fetch_sub(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
536 return Impl::atomic_fetch_oper(Impl::SubOper(), dest,
val);
538 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
540 atomic_fetch_sub(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
541 return Impl::atomic_fetch_oper(Impl::SubOper(), dest,
val);
// Public API: atomic fetch-then-multiply (returns the OLD value).
543 template <
typename S>
545 atomic_fetch_mul(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
546 return Impl::atomic_fetch_oper(Impl::MulOper(), dest,
val);
548 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
550 atomic_fetch_mul(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
551 return Impl::atomic_fetch_oper(Impl::MulOper(), dest,
val);
// Public API: atomic fetch-then-divide (returns the OLD value).
553 template <
typename S>
555 atomic_fetch_div(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
556 return Impl::atomic_fetch_oper(Impl::DivOper(), dest,
val);
558 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
560 atomic_fetch_div(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
561 return Impl::atomic_fetch_oper(Impl::DivOper(), dest,
val);
#define SACADO_INLINE_FUNCTION
#define SACADO_FAD_THREAD_SINGLE
#define SACADO_FAD_DERIV_LOOP(I, SZ)
T derived_type
Typename of derived object, returned by derived()
SimpleFad< ValueT > max(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
SimpleFad< ValueT > min(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
Get the base Fad type from a view/expression.