Line | Branch | Exec | Source |
---|---|---|---|
1 | // SPDX-FileCopyrightText: 2023 - 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
2 | // | ||
3 | // SPDX-License-Identifier: Apache-2.0 | ||
4 | |||
5 | #ifndef KLEIDICV_SVE2_H | ||
6 | #define KLEIDICV_SVE2_H | ||
7 | |||
8 | #include <arm_sve.h> | ||
9 | |||
10 | #include <utility> | ||
11 | |||
12 | #include "kleidicv/operations.h" | ||
13 | #include "kleidicv/utils.h" | ||
14 | |||
15 | // It is used by SVE2 and SME, the actual namespace will reflect it. | ||
16 | namespace KLEIDICV_TARGET_NAMESPACE { | ||
17 | |||
18 | // Context associated with SVE operations. | ||
19 | class Context { | ||
20 | public: | ||
21 | 91915 | explicit Context(svbool_t &pg) KLEIDICV_STREAMING : pg_{pg} {} | |
22 | |||
23 | // Sets the predicate associated with the context to a given predicate. | ||
24 | 111260 | void set_predicate(svbool_t pg) KLEIDICV_STREAMING { pg_ = pg; } | |
25 | |||
26 | // Returns predicate associated with the context. | ||
27 | 588580 | svbool_t predicate() const KLEIDICV_STREAMING { return pg_; } | |
28 | |||
29 | protected: | ||
30 | // Hold a reference to an svbool_t because a sizeless type cannot be a member. | ||
31 | svbool_t &pg_; | ||
32 | }; // end of class Context | ||
33 | |||
34 | // Primary template to describe logically grouped properties of vectors. | ||
35 | template <typename ScalarType> | ||
36 | class VectorTypes; | ||
37 | |||
38 | template <> | ||
39 | class VectorTypes<int8_t> { | ||
40 | public: | ||
41 | using ScalarType = int8_t; | ||
42 | using VectorType = svint8_t; | ||
43 | using Vector2Type = svint8x2_t; | ||
44 | using Vector3Type = svint8x3_t; | ||
45 | using Vector4Type = svint8x4_t; | ||
46 | }; // end of class VectorTypes<int8_t> | ||
47 | |||
48 | template <> | ||
49 | class VectorTypes<uint8_t> { | ||
50 | public: | ||
51 | using ScalarType = uint8_t; | ||
52 | using VectorType = svuint8_t; | ||
53 | using Vector2Type = svuint8x2_t; | ||
54 | using Vector3Type = svuint8x3_t; | ||
55 | using Vector4Type = svuint8x4_t; | ||
56 | }; // end of class VectorTypes<uint8_t> | ||
57 | |||
58 | template <> | ||
59 | class VectorTypes<int16_t> { | ||
60 | public: | ||
61 | using ScalarType = int16_t; | ||
62 | using VectorType = svint16_t; | ||
63 | using Vector2Type = svint16x2_t; | ||
64 | using Vector3Type = svint16x3_t; | ||
65 | using Vector4Type = svint16x4_t; | ||
66 | }; // end of class VectorTypes<int16_t> | ||
67 | |||
68 | template <> | ||
69 | class VectorTypes<uint16_t> { | ||
70 | public: | ||
71 | using ScalarType = uint16_t; | ||
72 | using VectorType = svuint16_t; | ||
73 | using Vector2Type = svuint16x2_t; | ||
74 | using Vector3Type = svuint16x3_t; | ||
75 | using Vector4Type = svuint16x4_t; | ||
76 | }; // end of class VectorTypes<uint16_t> | ||
77 | |||
78 | template <> | ||
79 | class VectorTypes<int32_t> { | ||
80 | public: | ||
81 | using ScalarType = int32_t; | ||
82 | using VectorType = svint32_t; | ||
83 | using Vector2Type = svint32x2_t; | ||
84 | using Vector3Type = svint32x3_t; | ||
85 | using Vector4Type = svint32x4_t; | ||
86 | }; // end of class VectorTypes<int32_t> | ||
87 | |||
88 | template <> | ||
89 | class VectorTypes<uint32_t> { | ||
90 | public: | ||
91 | using ScalarType = uint32_t; | ||
92 | using VectorType = svuint32_t; | ||
93 | using Vector2Type = svuint32x2_t; | ||
94 | using Vector3Type = svuint32x3_t; | ||
95 | using Vector4Type = svuint32x4_t; | ||
96 | }; // end of class VectorTypes<uint32_t> | ||
97 | |||
98 | template <> | ||
99 | class VectorTypes<int64_t> { | ||
100 | public: | ||
101 | using ScalarType = int64_t; | ||
102 | using VectorType = svint64_t; | ||
103 | using Vector2Type = svint64x2_t; | ||
104 | using Vector3Type = svint64x3_t; | ||
105 | using Vector4Type = svint64x4_t; | ||
106 | }; // end of class VectorTypes<int64_t> | ||
107 | |||
108 | template <> | ||
109 | class VectorTypes<uint64_t> { | ||
110 | public: | ||
111 | using ScalarType = uint64_t; | ||
112 | using VectorType = svuint64_t; | ||
113 | using Vector2Type = svuint64x2_t; | ||
114 | using Vector3Type = svuint64x3_t; | ||
115 | using Vector4Type = svuint64x4_t; | ||
116 | }; // end of class VectorTypes<uint64_t> | ||
117 | |||
118 | template <> | ||
119 | class VectorTypes<float> { | ||
120 | public: | ||
121 | using ScalarType = float; | ||
122 | using VectorType = svfloat32_t; | ||
123 | using Vector2Type = svfloat32x2_t; | ||
124 | using Vector3Type = svfloat32x3_t; | ||
125 | using Vector4Type = svfloat32x4_t; | ||
126 | }; // end of class VectorTypes<float> | ||
127 | |||
128 | template <> | ||
129 | class VectorTypes<double> { | ||
130 | public: | ||
131 | using ScalarType = double; | ||
132 | using VectorType = svfloat64_t; | ||
133 | using Vector2Type = svfloat64x2_t; | ||
134 | using Vector3Type = svfloat64x3_t; | ||
135 | using Vector4Type = svfloat64x4_t; | ||
136 | }; // end of class VectorTypes<double> | ||
137 | |||
138 | // Base class for all SVE vector traits. | ||
139 | template <typename ScalarType> | ||
140 | class VecTraitsBase : public VectorTypes<ScalarType> { | ||
141 | public: | ||
142 | using typename VectorTypes<ScalarType>::VectorType; | ||
143 | using typename VectorTypes<ScalarType>::Vector2Type; | ||
144 | |||
145 | // Number of lanes in a vector. | ||
146 | 138552 | static inline size_t num_lanes() KLEIDICV_STREAMING { | |
147 | 138552 | return static_cast<size_t>(svcnt()); | |
148 | } | ||
149 | |||
150 | // Maximum number of lanes in a vector. | ||
151 | static constexpr size_t max_num_lanes() KLEIDICV_STREAMING { | ||
152 | return 256 / sizeof(ScalarType); | ||
153 | } | ||
154 | |||
155 | // Loads a single vector from 'src'. | ||
156 | 25810 | static inline void load(Context ctx, const ScalarType *src, | |
157 | VectorType &vec) KLEIDICV_STREAMING { | ||
158 | 25810 | vec = svld1(ctx.predicate(), &src[0]); | |
159 | 25810 | } | |
160 | |||
161 | // Loads two consecutive vectors from 'src'. | ||
162 | 93385 | static inline void load_consecutive(Context ctx, const ScalarType *src, | |
163 | VectorType &vec_0, | ||
164 | VectorType &vec_1) KLEIDICV_STREAMING { | ||
165 | #if KLEIDICV_TARGET_SME2 | ||
166 | // Assuming that ctx contains a full predicate. | ||
167 | (void)ctx; | ||
168 | svcount_t p_counter = svptrue_c(); | ||
169 | Vector2Type v = svld1_x2(p_counter, &src[0]); | ||
170 | vec_0 = svget2(v, 0); | ||
171 | vec_1 = svget2(v, 1); | ||
172 | #else | ||
173 | 93385 | vec_0 = svld1(ctx.predicate(), &src[0]); | |
174 | 93385 | vec_1 = svld1_vnum(ctx.predicate(), &src[0], 1); | |
175 | #endif | ||
176 | 93385 | } | |
177 | |||
178 | // Stores a single vector to 'dst'. | ||
179 | 17416 | static inline void store(Context ctx, VectorType vec, | |
180 | ScalarType *dst) KLEIDICV_STREAMING { | ||
181 | 17416 | svst1(ctx.predicate(), &dst[0], vec); | |
182 | 17416 | } | |
183 | |||
184 | // Stores two consecutive vectors to 'dst'. | ||
185 | 60533 | static inline void store_consecutive(Context ctx, VectorType vec_0, | |
186 | VectorType vec_1, | ||
187 | ScalarType *dst) KLEIDICV_STREAMING { | ||
188 | #if KLEIDICV_TARGET_SME2 | ||
189 | // Assuming that ctx contains a full predicate. | ||
190 | (void)ctx; | ||
191 | svcount_t p_counter = svptrue_c(); | ||
192 | Vector2Type v = svcreate2(vec_0, vec_1); | ||
193 | svst1(p_counter, &dst[0], v); | ||
194 | #else | ||
195 | 60533 | svst1(ctx.predicate(), &dst[0], vec_0); | |
196 | 60533 | svst1_vnum(ctx.predicate(), &dst[0], 1, vec_1); | |
197 | #endif | ||
198 | 60533 | } | |
199 | |||
200 | template <typename T = ScalarType> | ||
201 | 74476 | static std::enable_if_t<sizeof(T) == sizeof(int8_t), uint64_t> svcnt() | |
202 | KLEIDICV_STREAMING { | ||
203 | 74476 | return svcntb(); | |
204 | } | ||
205 | |||
206 | template <typename T = ScalarType> | ||
207 | 34978 | static std::enable_if_t<sizeof(T) == sizeof(int16_t), uint64_t> svcnt() | |
208 | KLEIDICV_STREAMING { | ||
209 | 34978 | return svcnth(); | |
210 | } | ||
211 | |||
212 | template <typename T = ScalarType> | ||
213 | 27244 | static std::enable_if_t<sizeof(T) == sizeof(int32_t), uint64_t> svcnt() | |
214 | KLEIDICV_STREAMING { | ||
215 | 27244 | return svcntw(); | |
216 | } | ||
217 | |||
218 | template <typename T = ScalarType> | ||
219 | 1854 | static std::enable_if_t<sizeof(T) == sizeof(int64_t), uint64_t> svcnt() | |
220 | KLEIDICV_STREAMING { | ||
221 | 1854 | return svcntd(); | |
222 | } | ||
223 | |||
224 | template <typename T = ScalarType> | ||
225 | 454 | static std::enable_if_t<sizeof(T) == sizeof(int8_t), uint64_t> svcntp( | |
226 | svbool_t pg) KLEIDICV_STREAMING { | ||
227 | 454 | return svcntp_b8(pg, pg); | |
228 | } | ||
229 | |||
230 | template <typename T = ScalarType> | ||
231 | static std::enable_if_t<sizeof(T) == sizeof(int16_t), uint64_t> svcntp( | ||
232 | svbool_t pg) KLEIDICV_STREAMING { | ||
233 | return svcntp_b16(pg, pg); | ||
234 | } | ||
235 | |||
236 | template <typename T = ScalarType> | ||
237 | static std::enable_if_t<sizeof(T) == sizeof(int32_t), uint64_t> svcntp( | ||
238 | svbool_t pg) KLEIDICV_STREAMING { | ||
239 | return svcntp_b32(pg, pg); | ||
240 | } | ||
241 | |||
242 | template <typename T = ScalarType> | ||
243 | static std::enable_if_t<sizeof(T) == sizeof(int64_t), uint64_t> svcntp( | ||
244 | svbool_t pg) KLEIDICV_STREAMING { | ||
245 | return svcntp_b64(pg, pg); | ||
246 | } | ||
247 | |||
248 | template <typename T = ScalarType> | ||
249 | 91075 | static std::enable_if_t<sizeof(T) == sizeof(int8_t), svbool_t> svptrue() | |
250 | KLEIDICV_STREAMING { | ||
251 | 91075 | return svptrue_b8(); | |
252 | } | ||
253 | |||
254 | template <typename T = ScalarType> | ||
255 | 52995 | static std::enable_if_t<sizeof(T) == sizeof(int16_t), svbool_t> svptrue() | |
256 | KLEIDICV_STREAMING { | ||
257 | 52995 | return svptrue_b16(); | |
258 | } | ||
259 | |||
260 | template <typename T = ScalarType> | ||
261 | 83772 | static std::enable_if_t<sizeof(T) == sizeof(int32_t), svbool_t> svptrue() | |
262 | KLEIDICV_STREAMING { | ||
263 | 83772 | return svptrue_b32(); | |
264 | } | ||
265 | |||
266 | template <typename T = ScalarType> | ||
267 | 10908 | static std::enable_if_t<sizeof(T) == sizeof(int64_t), svbool_t> svptrue() | |
268 | KLEIDICV_STREAMING { | ||
269 | 10908 | return svptrue_b64(); | |
270 | } | ||
271 | |||
272 | #if KLEIDICV_TARGET_SME2 | ||
273 | template <typename T = ScalarType> | ||
274 | static std::enable_if_t<sizeof(T) == sizeof(int8_t), svcount_t> svptrue_c() | ||
275 | KLEIDICV_STREAMING { | ||
276 | return svptrue_c8(); | ||
277 | } | ||
278 | |||
279 | template <typename T = ScalarType> | ||
280 | static std::enable_if_t<sizeof(T) == sizeof(int16_t), svcount_t> svptrue_c() | ||
281 | KLEIDICV_STREAMING { | ||
282 | return svptrue_c16(); | ||
283 | } | ||
284 | |||
285 | template <typename T = ScalarType> | ||
286 | static std::enable_if_t<sizeof(T) == sizeof(int32_t), svcount_t> svptrue_c() | ||
287 | KLEIDICV_STREAMING { | ||
288 | return svptrue_c32(); | ||
289 | } | ||
290 | |||
291 | template <typename T = ScalarType> | ||
292 | static std::enable_if_t<sizeof(T) == sizeof(int64_t), svcount_t> svptrue_c() | ||
293 | KLEIDICV_STREAMING { | ||
294 | return svptrue_c64(); | ||
295 | } | ||
296 | #endif // KLEIDICV_TARGET_SME2 | ||
297 | |||
298 | template <enum svpattern pat, typename T = ScalarType> | ||
299 | 49104 | static std::enable_if_t<sizeof(T) == sizeof(int8_t), svbool_t> svptrue_pat() | |
300 | KLEIDICV_STREAMING { | ||
301 | 49104 | return svptrue_pat_b8(pat); | |
302 | } | ||
303 | |||
304 | template <enum svpattern pat, typename T = ScalarType> | ||
305 | 59008 | static std::enable_if_t<sizeof(T) == sizeof(int16_t), svbool_t> svptrue_pat() | |
306 | KLEIDICV_STREAMING { | ||
307 | 59008 | return svptrue_pat_b16(pat); | |
308 | } | ||
309 | |||
310 | template <enum svpattern pat, typename T = ScalarType> | ||
311 | 68704 | static std::enable_if_t<sizeof(T) == sizeof(int32_t), svbool_t> svptrue_pat() | |
312 | KLEIDICV_STREAMING { | ||
313 | 68704 | return svptrue_pat_b32(pat); | |
314 | } | ||
315 | |||
316 | template <enum svpattern pat, typename T = ScalarType> | ||
317 | static std::enable_if_t<sizeof(T) == sizeof(int64_t), svbool_t> svptrue_pat() | ||
318 | KLEIDICV_STREAMING { | ||
319 | return svptrue_pat_b64(pat); | ||
320 | } | ||
321 | |||
322 | template <typename IndexType, typename T = ScalarType> | ||
323 | 65427 | static std::enable_if_t<sizeof(T) == sizeof(int8_t), svbool_t> svwhilelt( | |
324 | IndexType index, IndexType max_index) KLEIDICV_STREAMING { | ||
325 | if constexpr (std::is_same_v<IndexType, size_t>) { | ||
326 | 59811 | return svwhilelt_b8_u64(index, max_index); | |
327 | } else if constexpr (std::is_same_v<IndexType, ptrdiff_t>) { | ||
328 | 5616 | return svwhilelt_b8_s64(index, max_index); | |
329 | } else { | ||
330 | return svwhilelt_b8(index, max_index); | ||
331 | } | ||
332 | } | ||
333 | |||
334 | template <typename IndexType, typename T = ScalarType> | ||
335 | 22205 | static std::enable_if_t<sizeof(T) == sizeof(int16_t), svbool_t> svwhilelt( | |
336 | IndexType index, IndexType max_index) KLEIDICV_STREAMING { | ||
337 | if constexpr (std::is_same_v<IndexType, size_t>) { | ||
338 | 19253 | return svwhilelt_b16_u64(index, max_index); | |
339 | } else if constexpr (std::is_same_v<IndexType, ptrdiff_t>) { | ||
340 | 2952 | return svwhilelt_b16_s64(index, max_index); | |
341 | } else { | ||
342 | return svwhilelt_b16(index, max_index); | ||
343 | } | ||
344 | } | ||
345 | |||
346 | template <typename IndexType, typename T = ScalarType> | ||
347 | 16009 | static std::enable_if_t<sizeof(T) == sizeof(int32_t), svbool_t> svwhilelt( | |
348 | IndexType index, IndexType max_index) KLEIDICV_STREAMING { | ||
349 | if constexpr (std::is_same_v<IndexType, size_t>) { | ||
350 | 16009 | return svwhilelt_b32_u64(index, max_index); | |
351 | } else if constexpr (std::is_same_v<IndexType, ptrdiff_t>) { | ||
352 | return svwhilelt_b32_s64(index, max_index); | ||
353 | } else { | ||
354 | return svwhilelt_b32(index, max_index); | ||
355 | } | ||
356 | } | ||
357 | |||
358 | template <typename IndexType, typename T = ScalarType> | ||
359 | 1342 | static std::enable_if_t<sizeof(T) == sizeof(int64_t), svbool_t> svwhilelt( | |
360 | IndexType index, IndexType max_index) KLEIDICV_STREAMING { | ||
361 | if constexpr (std::is_same_v<IndexType, size_t>) { | ||
362 | 1342 | return svwhilelt_b64_u64(index, max_index); | |
363 | } else if constexpr (std::is_same_v<IndexType, ptrdiff_t>) { | ||
364 | return svwhilelt_b64_s64(index, max_index); | ||
365 | } else { | ||
366 | return svwhilelt_b64(index, max_index); | ||
367 | } | ||
368 | } | ||
369 | |||
370 | // Transforms a single predicate into three other predicates that then can be | ||
371 | // used for consecutive operations. The input predicate can only have | ||
372 | // consecutive ones starting at the lowest element. | ||
373 | 120 | static void make_consecutive_predicates(svbool_t pg, svbool_t &pg_0, | |
374 | svbool_t &pg_1, | ||
375 | svbool_t &pg_2) KLEIDICV_STREAMING { | ||
376 | // Length of data. Must be signed because of the unconditional subtraction | ||
377 | // of fixed values. | ||
378 | 120 | int64_t length = 3 * svcntp(pg); | |
379 | // Handle up to VL length worth of data with the first predicated operation. | ||
380 | 120 | pg_0 = svwhilelt(int64_t{0}, length); | |
381 | // Handle up to VL length worth of data with the second predicated operation | ||
382 | // taking into account data stored in the first predicated operation. | ||
383 | 120 | length -= num_lanes(); | |
384 | 120 | pg_1 = svwhilelt(int64_t{0}, length); | |
385 | // Handle up to VL length worth of data with the second predicated operation | ||
386 | // taking into account data stored in the first and second predicated | ||
387 | // operations. | ||
388 | 120 | length -= num_lanes(); | |
389 | 120 | pg_2 = svwhilelt(int64_t{0}, length); | |
390 | 120 | } | |
391 | |||
392 | // Transforms a single predicate into four other predicates that then can be | ||
393 | // used for consecutive operations. The input predicate can only have | ||
394 | // consecutive ones starting at the lowest element. | ||
395 | 334 | static void make_consecutive_predicates(svbool_t pg, svbool_t &pg_0, | |
396 | svbool_t &pg_1, svbool_t &pg_2, | ||
397 | svbool_t &pg_3) KLEIDICV_STREAMING { | ||
398 | // Length of data. Must be signed because of the unconditional subtraction | ||
399 | // of fixed values. | ||
400 | 334 | int64_t length = 4 * svcntp(pg); | |
401 | // Handle up to VL length worth of data with the first predicated operation. | ||
402 | 334 | pg_0 = svwhilelt(int64_t{0}, length); | |
403 | // Handle up to VL length worth of data with the second predicated operation | ||
404 | // taking into account data stored in the first predicated operation. | ||
405 | 334 | length -= num_lanes(); | |
406 | 334 | pg_1 = svwhilelt(int64_t{0}, length); | |
407 | // Handle up to VL length worth of data with the second predicated operation | ||
408 | // taking into account data stored in the first and second predicated | ||
409 | // operations. | ||
410 | 334 | length -= num_lanes(); | |
411 | 334 | pg_2 = svwhilelt(int64_t{0}, length); | |
412 | // Handle up to VL length worth of data with the third predicated operation | ||
413 | // taking into account data stored in the first, second and third predicated | ||
414 | // operations. | ||
415 | 334 | length -= num_lanes(); | |
416 | 334 | pg_3 = svwhilelt(int64_t{0}, length); | |
417 | 334 | } | |
418 | }; // end of class VecTraitsBase<ScalarType> | ||
419 | |||
420 | // Primary template for SVE vector traits. | ||
421 | template <typename ScalarType> | ||
422 | class VecTraits : public VecTraitsBase<ScalarType> {}; | ||
423 | |||
424 | template <> | ||
425 | class VecTraits<int8_t> : public VecTraitsBase<int8_t> { | ||
426 | public: | ||
427 | 592 | static inline svint8_t svdup(int8_t v) KLEIDICV_STREAMING { | |
428 | 592 | return svdup_s8(v); | |
429 | } | ||
430 | 2079 | static inline svint8_t svreinterpret(svuint8_t v) KLEIDICV_STREAMING { | |
431 | 2079 | return svreinterpret_s8(v); | |
432 | } | ||
433 | 1386 | static inline svint8_t svasr_n(svbool_t pg, svint8_t v, | |
434 | uint8_t s) KLEIDICV_STREAMING { | ||
435 | 1386 | return svasr_n_s8_x(pg, v, s); | |
436 | } | ||
437 | }; // end of class VecTraits<int8_t> | ||
438 | |||
439 | template <> | ||
440 | class VecTraits<uint8_t> : public VecTraitsBase<uint8_t> { | ||
441 | public: | ||
442 | 5803 | static inline svuint8_t svdup(uint8_t v) KLEIDICV_STREAMING { | |
443 | 5803 | return svdup_u8(v); | |
444 | } | ||
445 | 2079 | static inline svuint8_t svreinterpret(svint8_t v) KLEIDICV_STREAMING { | |
446 | 2079 | return svreinterpret_u8(v); | |
447 | } | ||
448 | static inline svuint8_t svsub(svbool_t pg, svuint8_t v, | ||
449 | svuint8_t u) KLEIDICV_STREAMING { | ||
450 | return svsub_u8_x(pg, v, u); | ||
451 | } | ||
452 | static inline svuint8_t svhsub(svbool_t pg, svuint8_t v, | ||
453 | svuint8_t u) KLEIDICV_STREAMING { | ||
454 | return svhsub_u8_x(pg, v, u); | ||
455 | } | ||
456 | }; // end of class VecTraits<uint8_t> | ||
457 | |||
458 | template <> | ||
459 | class VecTraits<int16_t> : public VecTraitsBase<int16_t> { | ||
460 | public: | ||
461 | 2511 | static inline svint16_t svdup(int16_t v) KLEIDICV_STREAMING { | |
462 | 2511 | return svdup_s16(v); | |
463 | } | ||
464 | static inline svint16_t svreinterpret(svuint16_t v) KLEIDICV_STREAMING { | ||
465 | return svreinterpret_s16(v); | ||
466 | } | ||
467 | }; // end of class VecTraits<int16_t> | ||
468 | |||
469 | template <> | ||
470 | class VecTraits<uint16_t> : public VecTraitsBase<uint16_t> { | ||
471 | public: | ||
472 | 1692 | static inline svuint16_t svdup(uint16_t v) KLEIDICV_STREAMING { | |
473 | 1692 | return svdup_u16(v); | |
474 | } | ||
475 | static inline svuint16_t svreinterpret(svint16_t v) KLEIDICV_STREAMING { | ||
476 | return svreinterpret_u16(v); | ||
477 | } | ||
478 | }; // end of class VecTraits<uint16_t> | ||
479 | |||
480 | template <> | ||
481 | class VecTraits<int32_t> : public VecTraitsBase<int32_t> { | ||
482 | public: | ||
483 | 1132 | static inline svint32_t svdup(int32_t v) KLEIDICV_STREAMING { | |
484 | 1132 | return svdup_s32(v); | |
485 | } | ||
486 | static inline svint32_t svreinterpret(svuint32_t v) KLEIDICV_STREAMING { | ||
487 | return svreinterpret_s32(v); | ||
488 | } | ||
489 | }; // end of class VecTraits<int32_t> | ||
490 | |||
491 | template <> | ||
492 | class VecTraits<uint32_t> : public VecTraitsBase<uint32_t> { | ||
493 | public: | ||
494 | 550 | static inline svuint32_t svdup(uint32_t v) KLEIDICV_STREAMING { | |
495 | 550 | return svdup_u32(v); | |
496 | } | ||
497 | static inline svuint32_t svreinterpret(svint32_t v) KLEIDICV_STREAMING { | ||
498 | return svreinterpret_u32(v); | ||
499 | } | ||
500 | }; // end of class VecTraits<uint32_t> | ||
501 | |||
502 | template <> | ||
503 | class VecTraits<int64_t> : public VecTraitsBase<int64_t> { | ||
504 | public: | ||
505 | static inline svint64_t svdup(int64_t v) KLEIDICV_STREAMING { | ||
506 | return svdup_s64(v); | ||
507 | } | ||
508 | static inline svint64_t svreinterpret(svuint64_t v) KLEIDICV_STREAMING { | ||
509 | return svreinterpret_s64(v); | ||
510 | } | ||
511 | }; // end of class VecTraits<int64_t> | ||
512 | |||
513 | template <> | ||
514 | class VecTraits<uint64_t> : public VecTraitsBase<uint64_t> { | ||
515 | public: | ||
516 | static inline svuint64_t svdup(uint64_t v) KLEIDICV_STREAMING { | ||
517 | return svdup_u64(v); | ||
518 | } | ||
519 | static inline svuint64_t svreinterpret(svint64_t v) KLEIDICV_STREAMING { | ||
520 | return svreinterpret_u64(v); | ||
521 | } | ||
522 | }; // end of class VecTraits<uint64_t> | ||
523 | |||
524 | template <> | ||
525 | class VecTraits<float> : public VecTraitsBase<float> { | ||
526 | public: | ||
527 | 600 | static inline svfloat32_t svdup(float v) KLEIDICV_STREAMING { | |
528 | 600 | return svdup_f32(v); | |
529 | } | ||
530 | static inline svfloat32_t svsub(svbool_t pg, svfloat32_t v, | ||
531 | svfloat32_t u) KLEIDICV_STREAMING { | ||
532 | return svsub_f32_x(pg, v, u); | ||
533 | } | ||
534 | }; // end of class VecTraits<float> | ||
535 | |||
536 | template <> | ||
537 | class VecTraits<double> : public VecTraitsBase<double> { | ||
538 | public: | ||
539 | 28 | static inline svfloat64_t svdup(double v) KLEIDICV_STREAMING { | |
540 | 28 | return svdup_f64(v); | |
541 | } | ||
542 | }; // end of class VecTraits<double> | ||
543 | |||
544 | // Adapter which adds context and forwards arguments. | ||
545 | template <typename OperationType> | ||
546 | class OperationContextAdapter : public OperationBase<OperationType> { | ||
547 | // Shorten rows: no need to write 'this->'. | ||
548 | using OperationBase<OperationType>::operation; | ||
549 | using OperationBase<OperationType>::num_lanes; | ||
550 | |||
551 | public: | ||
552 | using ContextType = Context; | ||
553 | using VecTraits = typename OperationBase<OperationType>::VecTraits; | ||
554 | |||
555 | 13616 | explicit OperationContextAdapter(OperationType &operation) KLEIDICV_STREAMING | |
556 | 13616 | : OperationBase<OperationType>(operation) {} | |
557 | |||
558 | // Forwards vector_path_2x() calls to the inner operation. | ||
559 | template <typename... ArgTypes> | ||
560 | 66227 | void vector_path_2x(ArgTypes &&...args) KLEIDICV_STREAMING { | |
561 | 66227 | svbool_t ctx_pg; | |
562 | 66227 | ContextType ctx{ctx_pg}; | |
563 | 66227 | ctx.set_predicate(VecTraits::svptrue()); | |
564 | 66227 | operation().vector_path_2x(ctx, std::forward<ArgTypes>(args)...); | |
565 | 66227 | } | |
566 | |||
567 | // Forwards vector_path() calls to the inner operation. | ||
568 | template <typename... ArgTypes> | ||
569 | 9090 | void vector_path(ArgTypes &&...args) KLEIDICV_STREAMING { | |
570 | 9090 | svbool_t ctx_pg; | |
571 | 9090 | ContextType ctx{ctx_pg}; | |
572 | 9090 | ctx.set_predicate(VecTraits::svptrue()); | |
573 | 9090 | operation().vector_path(ctx, std::forward<ArgTypes>(args)...); | |
574 | 9090 | } | |
575 | |||
576 | // Forwards remaining_path() calls to the inner operation if the concrete | ||
577 | // operation is unrolled once. | ||
578 | template <typename... ColumnTypes, typename T = OperationType> | ||
579 | 955 | std::enable_if_t<T::is_unrolled_once()> remaining_path( | |
580 | size_t length, ColumnTypes &&...columns) KLEIDICV_STREAMING { | ||
581 | 955 | svbool_t ctx_pg; | |
582 | 955 | ContextType ctx{ctx_pg}; | |
583 | 955 | ctx.set_predicate(VecTraits::svwhilelt(size_t{0}, length)); | |
584 | 955 | operation().remaining_path(ctx, std::forward<ColumnTypes>(columns)...); | |
585 | 955 | } | |
586 | |||
587 | // Forwards remaining_path() calls to the inner operation if the concrete | ||
588 | // operation is not unrolled once. | ||
589 | template <typename... ColumnTypes, typename T = OperationType> | ||
590 | 15643 | std::enable_if_t<!T::is_unrolled_once()> remaining_path( | |
591 | size_t length, ColumnTypes... columns) KLEIDICV_STREAMING { | ||
592 | 15643 | svbool_t ctx_pg; | |
593 | 15643 | ContextType ctx{ctx_pg}; | |
594 | |||
595 | 15643 | size_t index = 0; | |
596 | 15643 | ctx.set_predicate(VecTraits::svwhilelt(index, length)); | |
597 |
16/16✓ Branch 0 taken 7167 times.
✓ Branch 1 taken 6258 times.
✓ Branch 2 taken 5480 times.
✓ Branch 3 taken 3518 times.
✓ Branch 4 taken 1989 times.
✓ Branch 5 taken 1772 times.
✓ Branch 6 taken 1748 times.
✓ Branch 7 taken 1542 times.
✓ Branch 8 taken 1312 times.
✓ Branch 9 taken 1107 times.
✓ Branch 10 taken 963 times.
✓ Branch 11 taken 790 times.
✓ Branch 12 taken 426 times.
✓ Branch 13 taken 408 times.
✓ Branch 14 taken 260 times.
✓ Branch 15 taken 248 times.
|
34988 | while (svptest_first(VecTraits::svptrue(), ctx.predicate())) { |
598 | 19345 | operation().remaining_path(ctx, columns.at(index)...); | |
599 | // Update loop counter and calculate the next governing predicate. | ||
600 | 19345 | index += num_lanes(); | |
601 | 19345 | ctx.set_predicate(VecTraits::svwhilelt(index, length)); | |
602 | } | ||
603 | 15643 | } | |
604 | }; // end of class OperationContextAdapter<OperationType> | ||
605 | |||
606 | // Adapter which implements remaining_path() for general SVE2 operations. | ||
607 | template <typename OperationType> | ||
608 | class RemainingPathAdapter : public OperationBase<OperationType> { | ||
609 | public: | ||
610 | using ContextType = Context; | ||
611 | |||
612 | 13616 | explicit RemainingPathAdapter(OperationType &operation) KLEIDICV_STREAMING | |
613 | 13616 | : OperationBase<OperationType>(operation) {} | |
614 | |||
615 | // Forwards remaining_path() to either vector_path() or tail_path() of the | ||
616 | // inner operation depending on what is requested by the innermost operation. | ||
617 | template <typename... ArgTypes> | ||
618 | 20300 | void remaining_path(ArgTypes... args) KLEIDICV_STREAMING { | |
619 | if constexpr (OperationType::uses_tail_path()) { | ||
620 | 454 | this->operation().tail_path(std::forward<ArgTypes>(args)...); | |
621 | } else { | ||
622 | 19846 | this->operation().vector_path(std::forward<ArgTypes>(args)...); | |
623 | } | ||
624 | 20300 | } | |
625 | }; // end of class RemainingPathAdapter<OperationType> | ||
626 | |||
627 | // Shorthand for applying a generic unrolled SVE2 operation. | ||
628 | template <typename OperationType, typename... ArgTypes> | ||
629 | 12704 | void apply_operation_by_rows(OperationType &operation, | |
630 | ArgTypes &&...args) KLEIDICV_STREAMING { | ||
631 | 12704 | ForwardingOperation forwarding_operation{operation}; | |
632 | 12704 | OperationAdapter operation_adapter{forwarding_operation}; | |
633 | 12704 | RemainingPathAdapter remaining_path_adapter{operation_adapter}; | |
634 | 12704 | OperationContextAdapter context_adapter{remaining_path_adapter}; | |
635 | 12704 | RowBasedOperation row_based_operation{context_adapter}; | |
636 | 12704 | zip_rows(row_based_operation, std::forward<ArgTypes>(args)...); | |
637 | 12704 | } | |
638 | |||
639 | // Swap two variables, since some C++ Standard Library implementations do not | ||
640 | // allow using std::swap for SVE vectors. | ||
641 | template <typename T> | ||
642 | 7008 | static inline void swap_scalable(T &a, T &b) KLEIDICV_STREAMING { | |
643 | 7008 | T tmp = a; | |
644 | 7008 | a = b; | |
645 | 7008 | b = tmp; | |
646 | 7008 | } | |
647 | |||
648 | // The following wrapper is used as a workaround to treat SVE variables as a 2D | ||
649 | // array. | ||
650 | template <typename VectorType, size_t Rows, size_t Cols> | ||
651 | class ScalableVectorArray2D { | ||
652 | public: | ||
653 | std::reference_wrapper<VectorType> window[Rows][Cols]; | ||
654 | 95233404 | VectorType &operator()(int row, int col) KLEIDICV_STREAMING { | |
655 | 95233404 | return window[row][col].get(); | |
656 | } | ||
657 | }; | ||
658 | |||
659 | template <typename VectorType, size_t element_size> | ||
660 | class ScalableVectorArray1D { | ||
661 | public: | ||
662 | std::reference_wrapper<VectorType> window[element_size]; | ||
663 | 6205024 | VectorType &operator()(int index) KLEIDICV_STREAMING { | |
664 | 6205024 | return window[index].get(); | |
665 | } | ||
666 | }; | ||
667 | |||
668 | } // namespace KLEIDICV_TARGET_NAMESPACE | ||
669 | |||
670 | #endif // KLEIDICV_SVE2_H | ||
671 |