Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions Src/Base/AMReX_SIMD.H
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,111 @@ namespace amrex::simd
# if __cplusplus >= 202002L
using vir::cvt;
# endif

/** Vectorized ternary operator: select(mask, true_val, false_val)
*
* Selects elements from true_val where mask is true and from false_val
* where mask is false. Analogous to (mask ? true_val : false_val) for
* scalars.
*
* Note: both true_val and false_val are eagerly evaluated (function
* arguments). To guard against operations like division by zero,
* sanitize inputs before the operation rather than relying on
* conditional selection.
*
* Example:
* ```cpp
* template <typename T>
* T compute (T const& a, T const& b)
* {
* auto safe_b = amrex::simd::stdx::select(b != T(0), b, T(1));
* return amrex::simd::stdx::select(b != T(0), a / safe_b, T(0));
* }
* ```
*
* @see C++26 std::simd select
*
* @todo Remove when SIMD provider (vir-simd / C++26) provides select.
* https://github.com/mattkretz/vir-simd/issues/49
*/
template <typename T, typename Abi>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
vir::stdx::simd<T, Abi> select (
typename vir::stdx::simd<T, Abi>::mask_type const& mask,
vir::stdx::simd<T, Abi> const& true_val,
vir::stdx::simd<T, Abi> const& false_val)
{
vir::stdx::simd<T, Abi> result = false_val;
where(mask, result) = true_val;
return result;
}
#else
// fallback implementations for functions that are commonly used in portable code paths

/** True if the boolean value is true (scalar identity fallback for simd any_of)
*
* Example:
* ```cpp
* // Works for both simd_mask and scalar bool:
* auto mask = a > b;
* if (amrex::simd::stdx::any_of(mask)) { ... }
* ```
*
* @see https://en.cppreference.com/w/cpp/experimental/simd/any_of
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
bool any_of (bool const v) { return v; }

/// \cond DOXYGEN_IGNORE
namespace detail {
template <typename T>
struct where_expression {
bool mask;
T* value;

AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
where_expression& operator= (T const& new_val)
{
if (mask) { *value = new_val; }
return *this;
}
};
}
/// \endcond

/** Masked assignment expression (scalar fallback for simd where)
*
* Returns an expression object whose assignment operator conditionally
* updates value only when mask is true.
*
* Example:
* ```cpp
* // Works for both simd<T> and scalar T:
* auto mask = b > T(0);
* T result = T(0);
* amrex::simd::stdx::where(mask, result) = a / b;
* ```
*
* @see https://en.cppreference.com/w/cpp/experimental/simd/where
*/
template <typename T>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
detail::where_expression<T> where (bool const mask, T& value)
{
return {mask, &value};
}

/** Vectorized ternary operator (scalar fallback for simd select)
*
* @see select in the AMREX_USE_SIMD path above
* @see C++26 std::simd select
*/
template <typename T>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
T select (bool const mask, T const& true_val, T const& false_val)
{
return mask ? true_val : false_val;
}
#endif
}

Expand Down
44 changes: 44 additions & 0 deletions Tests/SIMD/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,50 @@ int main (int argc, char* argv[])
<< (err == 0 ? "PASSED" : "FAILED") << "\n";
}

// ================================================================
// Test 14: any_of, where, select — portable single-source
// Uses SIMDParticleReal<>, which is a SIMD vector when AMREX_USE_SIMD=ON
// and a plain scalar when OFF. The same code path exercises
// both the real SIMD and the scalar fallback implementations.
// ================================================================
{
using PReal_t = simd::SIMDParticleReal<>;

// safe reciprocal: 1/b where b != 0, else 0
auto b = PReal_t(2);
auto mask = b != PReal_t(0);
auto safe_b = simd::stdx::select(mask, b, PReal_t(1));
auto recip = simd::stdx::select(mask,
PReal_t(1) / safe_b,
PReal_t(0));

// any_of: at least one lane should be nonzero
AMREX_ALWAYS_ASSERT(simd::stdx::any_of(mask));

// where: masked assignment
auto acc = PReal_t(0);
simd::stdx::where(mask, acc) = recip;

// verify: b=2 everywhere → recip=0.5, acc=0.5
int err = 0;
auto check = [&] (ParticleReal got, ParticleReal expected) {
if (std::abs(got - expected) > ParticleReal(1.e-10)) { ++err; }
};
#ifdef AMREX_USE_SIMD
for (int lane = 0; lane < static_cast<int>(PReal_t::size()); ++lane) {
check(recip[lane], ParticleReal(0.5));
check(acc[lane], ParticleReal(0.5));
}
#else
check(recip, ParticleReal(0.5));
check(acc, ParticleReal(0.5));
#endif

nerrors += err;
Print() << "any_of + where + select (portable): "
<< (err == 0 ? "PASSED" : "FAILED") << "\n";
}

// ================================================================
// Final report
// ================================================================
Expand Down
Loading