Skip to content

Commit d1be739

Browse files
committed
Refactor cat and chunk of basic_vec
include/ChangeLog:

	* include/bits/simd_vec.h (_M_concat_data): Use canonical type
	for scalar to 1-element vector.
	(_M_assign_from): Refactor to use basic_mask implementation.
	(_S_concat, _M_chunk): Implement via _M_assign_from.
	* tests/creation.cpp: New test.
	* tests/mask.cpp: Move cat/chunk test to creation.cpp.
1 parent 71ad840 commit d1be739

File tree

3 files changed

+169
-112
lines changed

3 files changed

+169
-112
lines changed

include/bits/simd_vec.h

Lines changed: 100 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ namespace std::simd
333333
_M_concat_data() const
334334
{
335335
if constexpr (_S_is_scalar)
336-
return __vec_builtin_type<value_type, 1>{_M_data};
336+
return __vec_builtin_type<__canonical_vec_type_t<value_type>, 1>{_M_data};
337337
else
338338
return _M_data;
339339
}
@@ -595,70 +595,101 @@ namespace std::simd
595595
}
596596
}
597597

598-
template <typename _A0, typename... _As>
598+
template <typename... _As>
599599
[[__gnu__::__always_inline__]]
600600
constexpr void
601-
_M_assign_from(auto _Offset, const basic_vec<value_type, _A0>& __x0,
602-
const basic_vec<value_type, _As>&... __xs)
601+
_M_assign_from(auto _Offset, const basic_vec<value_type, _As>&... __xs)
603602
{
604-
if constexpr (_Offset.value >= _A0::_S_size)
605-
// make the pack as small as possible
606-
_M_assign_from(integral_constant<int, _Offset.value - _A0::_S_size>(), __xs...);
607-
else if constexpr (_A0::_S_size >= _S_size + _Offset.value)
608-
{
609-
if constexpr (_S_size == 1)
610-
_M_data = __x0[_Offset];
603+
constexpr int __nargs = sizeof...(_As);
604+
using _A0 = _As...[0];
605+
using _Alast = _As...[__nargs - 1];
606+
const auto& __x0 = __xs...[0];
607+
constexpr int __ninputs = (_As::_S_size + ...) - _Offset.value;
608+
if constexpr (_Offset.value >= _A0::_S_size || __ninputs - _Alast::_S_size >= _S_size)
609+
{ // can drop inputs at the front and/or back of the pack
610+
constexpr int __skip_front = __packs_to_skip_at_front(_Offset.value, _As::_S_size...);
611+
constexpr int __skip_back = __packs_to_skip_at_back(_Offset.value, _S_size,
612+
_As::_S_size...);
613+
static_assert(__skip_front > 0 || __skip_back > 0);
614+
constexpr auto [...__skip] = _IotaArray<__skip_front>;
615+
constexpr auto [...__is] = _IotaArray<__nargs - __skip_front - __skip_back>;
616+
constexpr int __new_offset = _Offset.value - (0 + ... + _As...[__skip]::_S_size);
617+
_M_assign_from(cw<__new_offset>, __xs...[__is + __skip_front]...);
618+
}
619+
else if constexpr (_S_is_scalar)
620+
{ // trivial conversion to one value_type
621+
_M_data = __x0[_Offset.value];
622+
}
623+
else if constexpr (_A0::_S_nreg >= 2 || _Alast::_S_nreg >= 2)
624+
{ // flatten first and/or last multi-register argument
625+
const auto& __xlast = __xs...[__nargs - 1];
626+
constexpr bool __flatten_first = _A0::_S_nreg >= 2;
627+
constexpr bool __flatten_last = __nargs > 1 && _Alast::_S_nreg >= 2;
628+
constexpr auto [...__is] = _IotaArray<__nargs - __flatten_first - __flatten_last>;
629+
if constexpr (__flatten_first && __flatten_last)
630+
_M_assign_from(_Offset, __x0._M_data0, __x0._M_data1, __xs...[__is + 1]...,
631+
__xlast._M_data0, __xlast._M_data1);
632+
else if constexpr (__flatten_first)
633+
_M_assign_from(_Offset, __x0._M_data0, __x0._M_data1, __xs...[__is + 1]...);
611634
else
612-
_M_data = _VecOps<_DataType>::_S_extract(__x0._M_concat_data(), _Offset);
635+
_M_assign_from(_Offset, __xs...[__is]..., __xlast._M_data0, __xlast._M_data1);
636+
}
637+
638+
// at this point __xs should be as small as possible; there may be some corner cases left
639+
640+
else if constexpr (__nargs == 1)
641+
{ // simple and optimal
642+
_M_data = _VecOps<_DataType>::_S_extract(__x0._M_concat_data(), _Offset);
643+
}
644+
else if constexpr (__nargs == 2 && _A0::_S_nreg == 1 && _Alast::_S_nreg == 1)
645+
{ // optimize concat of two input vectors (e.g. using palignr)
646+
constexpr auto [...__is] = _IotaArray<_S_full_size>;
647+
constexpr int __v2_offset = __x0._S_full_size;
648+
_M_data = __builtin_shufflevector(
649+
__x0._M_concat_data(), __xs...[1]._M_concat_data(), [](int __i) consteval {
650+
if (__i < _A0::_S_size)
651+
return __i;
652+
__i -= _A0::_S_size;
653+
if (__i < _Alast::_S_size)
654+
return __i + __v2_offset;
655+
else
656+
return -1;
657+
}(__is + _Offset.value)...);
658+
}
659+
else if (__is_const_known(__xs...) || (_As::_S_size + ...) == _S_size)
660+
{ // hard to optimize for the compiler, but necessary in constant expressions
661+
_M_data = _VecOps<_DataType>::_S_extract(
662+
__vec_concat_sized<__xs.size()...>(__xs._M_concat_data()...),
663+
_Offset);
613664
}
614665
else
615-
_M_data = _VecOps<_DataType>::_S_extract(
616-
__vec_concat_sized<__x0.size(), __xs.size()...>(__x0._M_concat_data(),
617-
__xs._M_concat_data()...),
618-
_Offset);
666+
{ // fallback to concatenation in memory => load the result
667+
alignas(_DataType) value_type
668+
__tmp[std::max((... + _As::_S_size), _Offset.value + _S_full_size)] = {};
669+
int __offset = 0;
670+
template for (const auto& __x : {__xs...})
671+
{
672+
__x._M_store(__tmp + __offset);
673+
__offset += __x.size.value;
674+
}
675+
__builtin_memcpy(&_M_data, __tmp + _Offset.value, sizeof(_M_data));
676+
}
619677
}
620678

621-
template <typename _A0>
622-
[[__gnu__::__always_inline__]]
623-
static constexpr basic_vec
624-
_S_concat(const basic_vec<value_type, _A0>& __x0) noexcept
625-
{ return static_cast<const basic_vec&>(__x0); }
679+
[[__gnu__::__always_inline__]]
680+
static constexpr basic_vec
681+
_S_concat(const basic_vec& __x0) noexcept
682+
{ return __x0; }
626683

627684
template <typename... _As>
628685
requires (sizeof...(_As) > 1)
629686
[[__gnu__::__always_inline__]]
630687
static constexpr basic_vec
631688
_S_concat(const basic_vec<value_type, _As>&... __xs) noexcept
632689
{
633-
using _A0 = _As...[0];
634-
using _A1 = _As...[1];
635-
if constexpr (!_S_is_partial && _A0::_S_size == _A1::_S_size)
636-
// simple power-of-2 concat
637-
return basic_vec::_S_init(__vec_concat(__xs._M_concat_data()...));
638-
else
639-
{
640-
#if VIR_EXTENSIONS && 0
641-
constexpr bool __simple_inserts
642-
= sizeof...(_As) == 2 && _A1::_S_size <= 2
643-
&& is_same_v<_DataType, typename basic_vec<value_type, _A0>::_DataType>;
644-
// TODO: sometimes concats can be better. But the conditions here are not sufficient.
645-
if (!__builtin_is_constant_evaluated() && __simple_inserts)
646-
{
647-
if constexpr (__simple_inserts)
648-
{
649-
const auto& __x0 = __xs...[0];
650-
const auto& __x1 = __xs...[1];
651-
basic_vec __r;
652-
__r._M_data = __x0._M_data;
653-
template for (int __i : _IotaArray<_A1::_S_size>)
654-
__r._M_data[_A0::_S_size + __i] = __x1[__i];
655-
return __r;
656-
}
657-
}
658-
#endif
659-
return basic_vec::_S_init(__vec_concat_sized<_As::_S_size...>(
660-
__xs._M_concat_data()...));
661-
}
690+
basic_vec __r;
691+
__r._M_assign_from(cw<0>, __xs...);
692+
return __r;
662693
}
663694

664695
[[__gnu__::__always_inline__]]
@@ -1813,50 +1844,35 @@ namespace std::simd
18131844
}
18141845
else if constexpr (__rem == 0)
18151846
{
1816-
using _Rp = array<_Vp, __n>;
1817-
if constexpr (sizeof(_Rp) == sizeof(*this))
1818-
{
1819-
static_assert(!_Vp::_S_is_partial);
1820-
return __builtin_bit_cast(_Rp, *this);
1821-
}
1822-
else
1823-
{
1824-
constexpr auto [...__is] = _IotaArray<__n>;
1825-
return _Rp {_Vp([&](int __i) { return (*this)[__i + __is * _Vp::_S_size]; })...};
1826-
}
1847+
array<_Vp, __n> __r;
1848+
template for (constexpr int __i : _IotaArray<__n>)
1849+
__r[__i]._M_assign_from(cw<_Vp::_S_size * __i>, _M_data0, _M_data1);
1850+
return __r;
18271851
}
18281852
else
18291853
{
1830-
constexpr auto [...__is] = _IotaArray<__n>;
1854+
constexpr auto [...__is] = _IotaArray<__n + 1>;
18311855
using _Rest = resize_t<__rem, _Vp>;
1832-
// can't bit-cast because the member order of tuple is reversed
1833-
return tuple(_Vp ([&](int __i) { return (*this)[__i + __is * _Vp::_S_size]; })...,
1834-
_Rest([&](int __i) { return (*this)[__i + __n * _Vp::_S_size]; }));
1856+
tuple<conditional_t<(__is < __n), _Vp, _Rest>...> __r;
1857+
template for (constexpr int __i : _IotaArray<__n + 1>)
1858+
std::get<__i>(__r)._M_assign_from(cw<_Vp::_S_size * __i>, _M_data0, _M_data1);
1859+
return __r;
18351860
}
18361861
}
18371862

1838-
template <typename _A0, typename... _As>
1863+
template <typename... _As>
18391864
[[__gnu__::__always_inline__]]
18401865
constexpr void
1841-
_M_assign_from(auto _Offset, const basic_vec<value_type, _A0>& __x0,
1842-
const basic_vec<value_type, _As>&... __xs)
1866+
_M_assign_from(auto _Offset, const basic_vec<value_type, _As>&... __xs)
18431867
{
1844-
if constexpr (_Offset.value >= _A0::_S_size)
1845-
// make the pack as small as possible
1846-
_M_assign_from(integral_constant<int, _Offset.value - _A0::_S_size>(), __xs...);
1847-
else
1848-
{
1849-
_M_data0._M_assign_from(_Offset, __x0, __xs...);
1850-
_M_data1._M_assign_from(integral_constant<int, _Offset + _DataType0::size>(),
1851-
__x0, __xs...);
1852-
}
1868+
_M_data0._M_assign_from(_Offset, __xs...);
1869+
_M_data1._M_assign_from(_Offset + _DataType0::size, __xs...);
18531870
}
18541871

1855-
template <typename _A0>
1856-
[[__gnu__::__always_inline__]]
1857-
static constexpr basic_vec
1858-
_S_concat(const basic_vec<value_type, _A0>& __x0) noexcept
1859-
{ return basic_vec(__x0); }
1872+
[[__gnu__::__always_inline__]]
1873+
static constexpr const basic_vec&
1874+
_S_concat(const basic_vec& __x0) noexcept
1875+
{ return __x0; }
18601876

18611877
template <typename _A0, typename... _As>
18621878
requires (sizeof...(_As) >= 1)
@@ -1875,8 +1891,8 @@ namespace std::simd
18751891
else if (__is_const_known(__x0, __xs...))
18761892
{
18771893
basic_vec __r;
1878-
__r._M_data0.template _M_assign_from(integral_constant<int, 0>(), __x0, __xs...);
1879-
__r._M_data1.template _M_assign_from(_DataType0::size, __x0, __xs...);
1894+
__r._M_data0.template _M_assign_from(cw<0>, __x0, __xs...);
1895+
__r._M_data1.template _M_assign_from(cw<_DataType0::size()>, __x0, __xs...);
18801896
return __r;
18811897
}
18821898
else

tests/creation.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/* SPDX-License-Identifier: BSD-3-Clause */
2+
/* Copyright © 2024–2025 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH
3+
* Matthias Kretz <m.kretz@gsi.de>
4+
*/
5+
6+
#include "unittest.h"
7+
8+
template <typename V>
9+
struct Tests
10+
{
11+
using T = typename V::value_type;
12+
using M = typename V::mask_type;
13+
14+
ADD_TEST(VecCatChunk) {
15+
std::tuple{test_iota<V>, test_iota<V, 1>},
16+
[](auto& t, const V v0, const V v1) {
17+
auto c = cat(v0, v1);
18+
t.verify_equal(c.size(), V::size() * 2);
19+
for (int i = 0; i < V::size(); ++i)
20+
t.verify_equal(c[i], v0[i])(i);
21+
for (int i = 0; i < V::size(); ++i)
22+
t.verify_equal(c[i + V::size()], v1[i])(i);
23+
const auto [c0, c1] = simd::chunk<V>(c);
24+
t.verify_equal(c0, v0);
25+
t.verify_equal(c1, v1);
26+
if constexpr (V::size() <= 35)
27+
{
28+
auto d = cat(v1, c, v0);
29+
for (int i = 0; i < V::size(); ++i)
30+
{
31+
t.verify_equal(d[i], v1[i])(i);
32+
t.verify_equal(d[i + V::size()], v0[i])(i);
33+
t.verify_equal(d[i + 2 * V::size()], v1[i])(i);
34+
t.verify_equal(d[i + 3 * V::size()], v0[i])(i);
35+
}
36+
const auto [...chunked] = simd::chunk<3>(d);
37+
t.verify_equal(cat(chunked...), d);
38+
}
39+
}
40+
};
41+
42+
ADD_TEST(MaskCatChunk) {
43+
std::tuple{M([](int i) { return 1 == (i & 1); }), M([](int i) { return 1 == (i % 3); })},
44+
[](auto& t, const M k0, const M k1) {
45+
auto c = cat(k0, k1);
46+
t.verify_equal(c.size(), V::size() * 2);
47+
for (int i = 0; i < V::size(); ++i)
48+
t.verify_equal(c[i], k0[i])(i);
49+
for (int i = 0; i < V::size(); ++i)
50+
t.verify_equal(c[i + V::size()], k1[i])(i);
51+
const auto [c0, c1] = simd::chunk<M>(c);
52+
t.verify_equal(c0, k0);
53+
t.verify_equal(c1, k1);
54+
if constexpr (V::size() <= 35)
55+
{
56+
auto d = cat(k1, c, k0);
57+
for (int i = 0; i < V::size(); ++i)
58+
{
59+
t.verify_equal(d[i], k1[i])(i);
60+
t.verify_equal(d[i + V::size()], k0[i])(i);
61+
t.verify_equal(d[i + 2 * V::size()], k1[i])(i);
62+
t.verify_equal(d[i + 3 * V::size()], k0[i])(i);
63+
}
64+
const auto [...chunked] = simd::chunk<3>(d);
65+
t.verify_equal(cat(chunked...), d);
66+
}
67+
}
68+
};
69+
};

tests/mask.cpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -184,34 +184,6 @@ template <typename V>
184184
}
185185
};
186186
#endif
187-
188-
ADD_TEST(Cat_n_Chunk) {
189-
std::tuple{M([](int i) { return 1 == (i & 1); }), M([](int i) { return 1 == (i % 3); })},
190-
[](auto& t, const M k0, const M k1) {
191-
auto c = cat(k0, k1);
192-
t.verify_equal(c.size(), V::size() * 2);
193-
for (int i = 0; i < V::size(); ++i)
194-
t.verify_equal(c[i], k0[i])(i);
195-
for (int i = 0; i < V::size(); ++i)
196-
t.verify_equal(c[i + V::size()], k1[i])(i);
197-
const auto [c0, c1] = simd::chunk<M>(c);
198-
t.verify_equal(c0, k0);
199-
t.verify_equal(c1, k1);
200-
if constexpr (V::size() <= 35)
201-
{
202-
auto d = cat(k1, c, k0);
203-
for (int i = 0; i < V::size(); ++i)
204-
{
205-
t.verify_equal(d[i], k1[i])(i);
206-
t.verify_equal(d[i + V::size()], k0[i])(i);
207-
t.verify_equal(d[i + 2 * V::size()], k1[i])(i);
208-
t.verify_equal(d[i + 3 * V::size()], k0[i])(i);
209-
}
210-
const auto [...chunked] = simd::chunk<3>(d);
211-
t.verify_equal(cat(chunked...), d);
212-
}
213-
}
214-
};
215187
};
216188

217189
#include "unittest.h"

0 commit comments

Comments (0)