Files
netgen/libsrc/core/simd_arm64.hpp
2025-11-09 15:16:08 +01:00

408 lines
11 KiB
C++

#include "arm_neon.h"
namespace ngcore
{
template <>
class SIMD<mask64,2>
{
int64x2_t mask;
public:
SIMD (int i)
{
mask[0] = i > 0 ? -1 : 0;
mask[1] = i > 1 ? -1 : 0;
}
SIMD (bool i0, bool i1) { mask[0] = i0 ? -1 : 0; mask[1] = i1 ? -1 : 0; }
SIMD (SIMD<mask64,1> i0, SIMD<mask64,1> i1) { mask[0] = i0[0]; mask[1] = i1[0]; }
// SIMD (float64x2_t _data) : mask{_data} { }
SIMD (int64x2_t _data) : mask{_data} { }
auto Data() const { return mask; }
static constexpr int Size() { return 2; }
// static NETGEN_INLINE SIMD<mask64, 2> GetMaskFromBits (unsigned int i);
int64_t operator[] (int i) const { return mask[i]; }
template <int I>
int64_t Get() const { return mask[I]; }
auto Lo() const { return mask[0]; }
auto Hi() const { return mask[1]; }
};
// *************************** int32 ***************************
template<>
class SIMD<int32_t,2>
{
  int32x2_t data;   // two 32-bit integer lanes
public:
  static constexpr int Size() { return 2; }
  SIMD() {}
  /// Broadcast one value into both lanes.
  SIMD (int32_t val) : data{val, val} {}
  /// Explicit lane values (lane 0 = v0, lane 1 = v1).
  SIMD (int32_t v0, int32_t v1) : data{v0, v1} {}
  /// Concatenate two single-lane vectors.
  SIMD (SIMD<int32_t,1> lo, SIMD<int32_t,1> hi) : data{lo[0], hi[0]} {}
  SIMD (std::array<int32_t, 2> arr) : data{arr[0], arr[1]} {}
  /// Wrap a raw NEON vector.
  SIMD (int32x2_t _data) : data(_data) {}
  NETGEN_INLINE auto Data() const { return data; }
  NETGEN_INLINE auto & Data() { return data; }
  SIMD<int32_t,1> Lo() const { return Get<0>(); }
  SIMD<int32_t,1> Hi() const { return Get<1>(); }
  int32_t operator[] (int i) const { return data[i]; }
  /// Writable lane access through the underlying storage.
  int32_t & operator[] (int i) { return reinterpret_cast<int32_t*>(&data)[i]; }
  template <int I>
  int32_t Get() const { return data[I]; }
  /// Lane k holds n0 + k.
  static SIMD FirstInt(int n0=0) { return SIMD(n0, n0 + 1); }
};
template<>
class SIMD<int32_t,4>
{
  int32x4_t data;   // four 32-bit integer lanes
public:
  static constexpr int Size() { return 4; }
  SIMD() {}
  /// Broadcast one value into all four lanes.
  SIMD (int32_t val) : data{val,val,val,val} {}
  SIMD (int32_t v0, int32_t v1, int32_t v2, int32_t v3) : data{v0,v1,v2,v3} { }
  SIMD (std::array<int32_t, 4> arr) : data{arr[0], arr[1], arr[2], arr[3]} { }
  SIMD (int32x4_t _data) { data = _data; }
  /// Concatenate two 2-lane halves (lo -> lanes 0,1; hi -> lanes 2,3).
  SIMD (SIMD<int32_t,2> lo, SIMD<int32_t,2> hi) : data{vcombine_s32(lo.Data(), hi.Data())} {}
  /// Load four consecutive values from p. The pointer is only read, so it is
  /// taken as pointer-to-const; non-const int32_t* arguments still bind.
  SIMD (const int32_t * p) : data{vld1q_s32(p)} { }
  NETGEN_INLINE auto Data() const { return data; }
  NETGEN_INLINE auto & Data() { return data; }
  SIMD<int32_t,2> Lo() const { return vget_low_s32(data); }
  SIMD<int32_t,2> Hi() const { return vget_high_s32(data); }
  int32_t operator[] (int i) const { return data[i]; }
  int32_t & operator[] (int i) { return ((int32_t*)&data)[i]; }
  /// Store the four lanes to p[0..3]. Does not modify *this, hence const.
  void Store (int32_t * p) const { vst1q_s32(p, data); }
  template <int I>
  int32_t Get() const { return data[I]; }
  /// Lane k holds n0 + k.
  static SIMD FirstInt(int n0=0) { return { n0+0, n0+1, n0+2, n0+3 }; }
};
/// Lane-wise minimum of two 2-lane int32 vectors.
NETGEN_INLINE auto Min (SIMD<int32_t,2> a, SIMD<int32_t,2> b) {
  const int32x2_t smaller = vmin_s32(a.Data(), b.Data());
  return SIMD<int32_t,2>(smaller);
}
/// Lane-wise maximum of two 2-lane int32 vectors.
NETGEN_INLINE auto Max (SIMD<int32_t,2> a, SIMD<int32_t,2> b) {
  const int32x2_t larger = vmax_s32(a.Data(), b.Data());
  return SIMD<int32_t,2>(larger);
}
/// Lane-wise minimum of two 4-lane int32 vectors.
NETGEN_INLINE auto Min (SIMD<int32_t,4> a, SIMD<int32_t,4> b) {
  const int32x4_t smaller = vminq_s32(a.Data(), b.Data());
  return SIMD<int32_t,4>(smaller);
}
/// Lane-wise maximum of two 4-lane int32 vectors.
NETGEN_INLINE auto Max (SIMD<int32_t,4> a, SIMD<int32_t,4> b) {
  const int32x4_t larger = vmaxq_s32(a.Data(), b.Data());
  return SIMD<int32_t,4>(larger);
}
// *************************** int64 ***************************
template<>
class SIMD<int64_t,2>
{
  int64x2_t data;   // two 64-bit integer lanes
public:
  static constexpr int Size() { return 2; }
  SIMD() {}
  /// Broadcast one value into both lanes.
  SIMD (int64_t val) : data{val,val} {}
  /// Explicit lane values (lane 0 = v0, lane 1 = v1).
  SIMD (int64_t v0, int64_t v1) : data{vcombine_s64(int64x1_t{v0}, int64x1_t{v1})} { }
  SIMD (std::array<int64_t, 2> arr) : data{arr[0], arr[1]} { }
  SIMD (int64x2_t _data) { data = _data; }
  NETGEN_INLINE auto Data() const { return data; }
  NETGEN_INLINE auto & Data() { return data; }
  int64_t Lo() const { return Get<0>(); }
  int64_t Hi() const { return Get<1>(); }
  int64_t operator[] (int i) const { return data[i]; }
  int64_t & operator[] (int i) { return ((int64_t*)&data)[i]; }
  template <int I>
  int64_t Get() const { return data[I]; }
  /// Lane k holds n0 + k. The increment is done in 64-bit arithmetic so that
  /// n0 == INT_MAX does not overflow `int` (the lanes are 64-bit anyway).
  static SIMD FirstInt(int n0=0) { return { int64_t(n0), int64_t(n0) + 1 }; }
};
/// Lane-wise bitwise AND of two int64 vectors.
NETGEN_INLINE SIMD<int64_t,2> operator& (SIMD<int64_t,2> a, SIMD<int64_t,2> b)
{
  const int64x2_t bits = vandq_s64(a.Data(), b.Data());
  return SIMD<int64_t,2>(bits);
}
/// Lane-wise addition of two int64 vectors.
NETGEN_INLINE SIMD<int64_t,2> operator+ (SIMD<int64_t,2> a, SIMD<int64_t,2> b)
{
  const int64x2_t sum = vaddq_s64(a.Data(), b.Data());
  return SIMD<int64_t,2>(sum);
}
/// Lane-wise equality of two int64 vectors.
/// @return mask with -1 (all bits set) in lanes where a == b, 0 elsewhere.
NETGEN_INLINE SIMD<mask64,2> operator== (SIMD<int64_t> a, SIMD<int64_t> b)
{
  // vceqq_s64 accepts the signed operands as-is (vceqq_u64 needed an implicit
  // int64x2_t -> uint64x2_t conversion, which only lax-vector-conversion
  // compilers accept). Its uint64x2_t result is reinterpreted explicitly
  // because SIMD<mask64,2> is constructed from int64x2_t.
  return vreinterpretq_s64_u64(vceqq_s64(a.Data(), b.Data()));
}
/// Lane-wise signed greater-than of two int64 vectors.
/// @return mask with -1 (all bits set) in lanes where a > b, 0 elsewhere.
NETGEN_INLINE SIMD<mask64,2> operator> (SIMD<int64_t> a, SIMD<int64_t> b)
{
  // vcgtq_s64 yields uint64x2_t; reinterpret explicitly instead of relying on
  // an implicit vector conversion into the int64x2_t mask constructor.
  return vreinterpretq_s64_u64(vcgtq_s64(a.Data(), b.Data()));
}
/// Lane-wise left shift by a compile-time amount N (passed as IC<N>).
/// The IC argument only carries N in its type; its value is unused.
template <int N>
SIMD<int64_t,2> operator<< (SIMD<int64_t,2> a, IC<N>)
{
  const int64x2_t shifted = vshlq_n_s64(a.Data(), N);
  return SIMD<int64_t,2>(shifted);
}
// *************************** double ***************************
template<>
class SIMD<double,2>
{
float64x2_t data;
public:
static constexpr int Size() { return 2; }
SIMD () {}
SIMD (const SIMD &) = default;
// SIMD (double v0, double v1) : data{v0,v1} { }
SIMD (double v0, double v1) : data{vcombine_f64(float64x1_t{v0}, float64x1_t{v1})} { }
SIMD (SIMD<double,1> v0, SIMD<double,1> v1) : data{vcombine_f64(float64x1_t{v0.Data()}, float64x1_t{v1.Data()})} { }
SIMD (std::array<double, 2> arr) : data{arr[0], arr[1]} { }
SIMD & operator= (const SIMD &) = default;
SIMD (double val) : data{val,val} { }
SIMD (int val) : data{double(val),double(val)} { }
SIMD (size_t val) : data{double(val),double(val)} { }
SIMD (double const * p)
{
data = vld1q_f64(p);
// data[0] = p[0];
// data[1] = p[1];
}
SIMD (double const * p, SIMD<mask64,2> mask)
{
data[0] = mask[0] ? p[0] : 0;
data[1] = mask[1] ? p[1] : 0;
}
SIMD (float64x2_t _data) { data = _data; }
template<typename T, typename std::enable_if<std::is_convertible<T, std::function<double(int)>>::value, int>::type = 0>
SIMD (const T & func)
{
data[0] = func(0);
data[1] = func(1);
}
void Store (double * p)
{
vst1q_f64(p, data);
/*
p[0] = data[0];
p[1] = data[1];
*/
}
void Store (double * p, SIMD<mask64,2> mask)
{
if (mask[0]) p[0] = data[0];
if (mask[1]) p[1] = data[1];
}
// NETGEN_INLINE double operator[] (int i) const { return ((double*)(&data))[i]; }
NETGEN_INLINE double operator[] (int i) const { return data[i]; }
NETGEN_INLINE double & operator[] (int i) { return ((double*)&data)[i]; }
template <int I>
double Get() const { return data[I]; }
NETGEN_INLINE auto Data() const { return data; }
NETGEN_INLINE auto & Data() { return data; }
operator std::tuple<double&,double&> ()
{
auto pdata = (double*)&data;
return std::tuple<double&,double&>(pdata[0], pdata[1]);
}
double Lo() const { return Get<0>(); } // data[0]; }
double Hi() const { return Get<1>(); } // data[1]; }
// double Hi() const { return vget_high_f64(data)[0]; }
};
/// Horizontal sum of both lanes: sd[0] + sd[1].
NETGEN_INLINE double HSum (SIMD<double,2> sd)
{
  const double first = sd.Get<0>();
  const double second = sd.Get<1>();
  return first + second;
}
/// Pairwise horizontal sums: returns { a[0]+a[1], b[0]+b[1] }.
NETGEN_INLINE SIMD<double,2> HSum (SIMD<double,2> a, SIMD<double,2> b)
{
  const float64x2_t sums = vpaddq_f64(a.Data(), b.Data());
  return SIMD<double,2>(sums);
}
/// Four horizontal sums packed into one vector: { ΣA, ΣB, ΣC, ΣD }.
NETGEN_INLINE SIMD<double,4> HSum(SIMD<double,2> a, SIMD<double,2> b, SIMD<double,2> c, SIMD<double,2> d)
{
  const SIMD<double,2> front = HSum(a, b);
  const SIMD<double,2> back = HSum(c, d);
  return SIMD<double,4>(front, back);
}
/// Swap the two lanes: returns { a[1], a[0] }.
NETGEN_INLINE SIMD<double,2> SwapPairs (SIMD<double,2> a)
{
  // vextq_f64(v, v, 1) == { v[1], v[0] }. Unlike __builtin_shufflevector
  // (clang, and GCC only from version 12) this is a plain ACLE intrinsic.
  return vextq_f64(a.Data(), a.Data(), 1);
}
// a*b+c
/// Lane-wise a*b + c, evaluated as a fused multiply-add (single rounding).
NETGEN_INLINE SIMD<double,2> FMA (SIMD<double,2> a, SIMD<double,2> b, SIMD<double,2> c)
{
  // ACLE specifies vmlaq_f64 as an unfused multiply followed by an add
  // (two roundings, two instructions); vfmaq_f64(c, a, b) computes c + a*b fused.
  return vfmaq_f64(c.Data(), a.Data(), b.Data());
}
/// Scalar-times-vector variant: broadcasts a, then a*b + c.
NETGEN_INLINE SIMD<double,2> FMA (const double & a, SIMD<double,2> b, SIMD<double,2> c)
{
  const SIMD<double,2> a_bcast(a);
  return FMA(a_bcast, b, c);
}
// -a*b+c
/// Lane-wise c - a*b, evaluated as a fused multiply-subtract (single rounding).
NETGEN_INLINE SIMD<double,2> FNMA (SIMD<double,2> a, SIMD<double,2> b, SIMD<double,2> c)
{
  // vmlsq_f64 is unfused per ACLE; vfmsq_f64(c, a, b) computes c - a*b fused.
  return vfmsq_f64(c.Data(), a.Data(), b.Data());
}
/// Scalar-times-vector variant: broadcasts a, then c - a*b.
NETGEN_INLINE SIMD<double,2> FNMA (const double & a, SIMD<double,2> b, SIMD<double,2> c)
{
  const SIMD<double,2> a_bcast(a);
  return FNMA(a_bcast, b, c);
}
// ARM complex mult:
// https://arxiv.org/pdf/1901.07294.pdf
// c += a*b, each vector holding one complex number as (re, im).
NETGEN_INLINE void FMAComplex (SIMD<double,2> a, SIMD<double,2> b, SIMD<double,2> & c)
{
#if defined(__ARM_FEATURE_COMPLEX)
  // Armv8.3 FCMLA: two rotations together form a full complex multiply-add.
  auto tmp = vcmlaq_f64(c.Data(), a.Data(), b.Data());   // += are * b
  c = vcmlaq_rot90_f64(tmp, a.Data(), b.Data());         // += i*aim * b
#else
  // Fallback for targets without the complex-arithmetic extension, where the
  // vcmla intrinsics are unavailable. Uses the same two fused steps:
  //   re += are*bre;  re += -aim*bim
  //   im += are*bim;  im +=  aim*bre
  const float64x2_t ad = a.Data();
  const float64x2_t bd = b.Data();
  const float64x2_t are = vdupq_laneq_f64(ad, 0);        // (are, are)
  const float64x2_t aim = vdupq_laneq_f64(ad, 1);        // (aim, aim)
  const float64x2_t bswap = vextq_f64(bd, bd, 1);        // (bim, bre)
  const float64x2_t sign = { -1.0, 1.0 };
  const float64x2_t tmp = vfmaq_f64(c.Data(), are, bd);  // += are * b
  c = vfmaq_f64(tmp, vmulq_f64(aim, sign), bswap);       // += (-aim*bim, aim*bre)
#endif
}
/// c += a*b for two packed complex numbers: processes each 2-lane half
/// with the 2-lane overload and reassembles the result.
NETGEN_INLINE void FMAComplex (SIMD<double,4> a, SIMD<double,4> b, SIMD<double,4> & c)
{
  SIMD<double,2> lower = c.Lo(), upper = c.Hi();
  FMAComplex(a.Lo(), b.Lo(), lower);
  FMAComplex(a.Hi(), b.Hi(), upper);
  c = SIMD<double,4>(lower, upper);
}
/// Lane-wise sum.
NETGEN_INLINE SIMD<double,2> operator+ (SIMD<double,2> a, SIMD<double,2> b)
{
  const float64x2_t sum = a.Data() + b.Data();
  return sum;
}
/// Lane-wise difference.
NETGEN_INLINE SIMD<double,2> operator- (SIMD<double,2> a, SIMD<double,2> b)
{
  const float64x2_t diff = a.Data() - b.Data();
  return diff;
}
/// Lane-wise negation.
NETGEN_INLINE SIMD<double,2> operator- (SIMD<double,2> a)
{
  const float64x2_t neg = -a.Data();
  return neg;
}
/// Lane-wise product.
NETGEN_INLINE SIMD<double,2> operator* (SIMD<double,2> a, SIMD<double,2> b)
{
  const float64x2_t prod = a.Data() * b.Data();
  return prod;
}
/// Lane-wise quotient.
NETGEN_INLINE SIMD<double,2> operator/ (SIMD<double,2> a, SIMD<double,2> b)
{
  const float64x2_t quot = a.Data() / b.Data();
  return quot;
}
/// Lane-wise square root.
NETGEN_INLINE SIMD<double,2> sqrt (SIMD<double,2> x)
{
  const float64x2_t root = vsqrtq_f64(x.Data());
  return SIMD<double,2>(root);
}
/// Lane-wise rounding to an integral value with vrndnq_f64
/// (round to nearest, ties to even).
NETGEN_INLINE SIMD<double,2> round (SIMD<double,2> x)
{
  const float64x2_t rounded = vrndnq_f64(x.Data());
  return SIMD<double,2>(rounded);
}
/// Lane-wise conversion to the nearest int64, halfway cases rounded away
/// from zero (the std::lround convention).
NETGEN_INLINE SIMD<int64_t,2> lround (SIMD<double,2> x)
{
  // vcvtq_s64_f64 truncates toward zero (so 2.7 became 2); vcvtaq_s64_f64
  // rounds to nearest with ties away from zero.
  return vcvtaq_s64_f64(x.Data());
}
/// Lane-wise reciprocal square root, 1/sqrt(x), computed with a full-precision
/// sqrt followed by a division.
/// NOTE: an estimate-based variant (vrsqrteq_f64 refined by Newton steps via
/// vrsqrtsq_f64) exists as an option but is not used here.
NETGEN_INLINE SIMD<double,2> rsqrt (SIMD<double,2> x)
{
  const SIMD<double,2> one(1.0);
  return one / sqrt(x);
}
/// Bit-level reinterpretation of two int64 lanes as two doubles
/// (no numeric conversion).
template <>
NETGEN_INLINE SIMD<double,2> Reinterpret (SIMD<int64_t,2> a)
{
  const float64x2_t same_bits = vreinterpretq_f64_s64(a.Data());
  return SIMD<double,2>(same_bits);
}
/// Lane-wise select: bits set in the mask take b, cleared bits take c.
NETGEN_INLINE SIMD<double,2> If (SIMD<mask64,2> a, SIMD<double,2> b, SIMD<double,2> c)
{
  // vbsl selects per bit, so an all-ones lane picks b and a zero lane picks c.
  return vbslq_f64(vreinterpretq_u64_s64(a.Data()), b.Data(), c.Data());
}
/// Lane-wise select for int64 vectors: set mask bits take b, cleared bits take c.
NETGEN_INLINE SIMD<int64_t,2> If (SIMD<mask64,2> a, SIMD<int64_t,2> b, SIMD<int64_t,2> c)
{
  return vbslq_s64(vreinterpretq_u64_s64(a.Data()), b.Data(), c.Data());
}
/// Lane-wise logical AND of two masks (bitwise AND of the lane bit patterns).
NETGEN_INLINE SIMD<mask64,2> operator&& (SIMD<mask64,2> a, SIMD<mask64,2> b)
{
  const uint64x2_t both = vandq_u64(vreinterpretq_u64_s64(a.Data()),
                                    vreinterpretq_u64_s64(b.Data()));
  return vreinterpretq_s64_u64(both);
}
}