Line | Branch | Exec | Source |
---|---|---|---|
1 | // SPDX-FileCopyrightText: 2023 - 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
2 | // | ||
3 | // SPDX-License-Identifier: Apache-2.0 | ||
4 | |||
5 | #include <climits> | ||
6 | #include <cmath> | ||
7 | #include <cstdint> | ||
8 | #include <limits> | ||
9 | |||
10 | #include "kleidicv/arithmetics/scale.h" | ||
11 | #include "kleidicv/neon.h" | ||
12 | #include "kleidicv/traits.h" | ||
13 | |||
14 | namespace kleidicv::neon { | ||
15 | |||
16 | // Scale algorithm: for each value in the source, | ||
17 | // dst[i] = src[i] * scale + shift (floating point operation) | ||
18 | // | ||
19 | // Unsigned 8-bit implementation | ||
20 | // | ||
21 | // Since converting from uint8 to float32 and back takes more steps, | ||
22 | // 'ScaleTbx' saves time by pre-calculating all 256 values and uses TBLs | ||
23 | // and TBXs to map the values directly from uint8 to uint8: | ||
24 | // i: 0 to 255: tbl[i] = i * scale + shift | ||
25 | // | ||
26 | // Since a single TBL instruction can map only 16 values, more TBX instructions | ||
27 | // are needed for the remaining 240 values. After the first TBL (which replaces | ||
28 | // values 0-15 with the corresponding entries from the table), 16 is subtracted | ||
29 | // from all lanes of the source vector before the next TBX, so indices 0 to 15 | ||
30 | // of the next table correspond to original source values 16 to 31. | ||
31 | // | ||
32 | // Example: | ||
33 | // scale = 1 | ||
34 | // shift = 100 | ||
35 | // Initialization (this also takes time, so it is not used for short inputs): | ||
36 | // tbl = [ 100, 101, 102, ..., 255, <100 times 255, it's saturated>] | ||
37 | // Copy table to vector registers: | ||
38 | // t0 = [ 100, ..., 115 ] | ||
39 | // t1 = [ 116, ..., 131 ] | ||
40 | // t2 = [ 132, ..., 147 ] | ||
41 | // ... | ||
42 | // t15 = [ 255, ..., 255 ] | ||
43 | // | ||
44 | // input: v = [ 21, 3, 39, 6 ] | ||
45 | // TBL(t0): d = [ 0, 103, 0, 106 ] // index >= 16 results in 0 | ||
46 | // SUB: v = [ 5, 243, 23, 246 ] // subtracted 16 --> next table | ||
47 | // TBX(t1): d = [ 121, 103, 0, 106 ] // lanes with index >= 16 are left unchanged | ||
48 | // SUB: v = [ 245, 227, 7, 230 ] // subtracted 16 --> next table | ||
49 | // TBX(t2): d = [ 121, 103, 139, 106 ] // lanes with index >= 16 are left unchanged | ||
50 | // ... etc. | ||
51 | // | ||
52 | // Bigger tables (32, 48 or 64 entries) can be indexed with TBX2, TBX3 or TBX4. | ||
53 | // In that case 2/3/4 * 16 (instead of 16) has to be subtracted from the source. | ||
54 | // The below solution (combining TBX2-TBX3) gives a good compromise between code | ||
55 | // size and speed. | ||
56 | |||
57 | template <typename ScalarType> | ||
58 | class ScaleIntBase : public UnrollTwice { | ||
59 | public: | ||
60 | 303 | ScaleIntBase(float scale, float shift) : scale_{scale}, shift_{shift} {} | |
61 | |||
62 | protected: | ||
63 | static constexpr ScalarType ScalarMax = | ||
64 | std::numeric_limits<ScalarType>::max(); | ||
65 | |||
66 | float scale_, shift_; | ||
67 | }; | ||
68 | |||
69 | template <typename T> | ||
70 | kleidicv_error_t scale(const T *src, size_t src_stride, T *dst, | ||
71 | size_t dst_stride, size_t width, size_t height, | ||
72 | float scale, float shift); | ||
73 | |||
74 | template <typename T> | ||
75 | 2713 | T scale_value(T value, float scale, float shift) { | |
76 | static constexpr T ScalarMax = std::numeric_limits<T>::max(); | ||
77 | 2713 | int64_t v = lrintf(static_cast<float>(value) * scale + shift); | |
78 |
2/2✓ Branch 0 taken 2428 times.
✓ Branch 1 taken 285 times.
|
2713 | if (static_cast<uint64_t>(v) <= ScalarMax) { |
79 | 2428 | return static_cast<T>(v); | |
80 | } | ||
81 | 285 | return static_cast<T>(v > 0 ? ScalarMax : 0); | |
82 | 2713 | } | |
83 | |||
84 | class ScaleUint8Tbx final : public ScaleIntBase<uint8_t> { | ||
85 | public: | ||
86 | using ScalarType = uint8_t; | ||
87 | using VecTraits = neon::VecTraits<ScalarType>; | ||
88 | using VectorType = typename VecTraits::VectorType; | ||
89 | using Vector2Type = typename VecTraits::Vector2Type; | ||
90 | using Vector3Type = typename VecTraits::Vector3Type; | ||
91 | |||
92 | 149 | ScaleUint8Tbx(float scale, float shift, const ScalarType *precalculated_table) | |
93 | 149 | : ScaleIntBase<uint8_t>(scale, shift), | |
94 | 149 | table_pointer_(precalculated_table), | |
95 | 149 | v_step3_(vdupq_n_u8(3 * VecTraits::num_lanes())), | |
96 | 149 | v_step2_(vdupq_n_u8(2 * VecTraits::num_lanes())) { | |
97 | 149 | VecTraits::load(precalculated_table, t0_3_); | |
98 | 149 | VecTraits::load(precalculated_table + 3 * VecTraits::num_lanes(), t1_3_); | |
99 | 298 | VecTraits::load(precalculated_table + (3 + 3) * VecTraits::num_lanes(), | |
100 | 149 | t2_2_); | |
101 | 298 | VecTraits::load(precalculated_table + (3 + 3 + 2) * VecTraits::num_lanes(), | |
102 | 149 | t3_3_); | |
103 | 149 | VecTraits::load( | |
104 | 149 | precalculated_table + (3 + 3 + 2 + 3) * VecTraits::num_lanes(), t4_2_); | |
105 | 149 | VecTraits::load( | |
106 | 149 | precalculated_table + (3 + 3 + 2 + 3 + 2) * VecTraits::num_lanes(), | |
107 | 149 | t5_3_); | |
108 | 149 | } | |
109 | 2400 | VectorType vector_path(VectorType src) { | |
110 | 2400 | VectorType dst = vqtbl3q_u8(t0_3_, src); | |
111 | 2400 | src = vsubq_u8(src, v_step3_); | |
112 | 2400 | dst = vqtbx3q_u8(dst, t1_3_, src); | |
113 | 2400 | src = vsubq_u8(src, v_step3_); | |
114 | 2400 | dst = vqtbx2q_u8(dst, t2_2_, src); | |
115 | 2400 | src = vsubq_u8(src, v_step2_); | |
116 | 2400 | dst = vqtbx3q_u8(dst, t3_3_, src); | |
117 | 2400 | src = vsubq_u8(src, v_step3_); | |
118 | 2400 | dst = vqtbx2q_u8(dst, t4_2_, src); | |
119 | 2400 | src = vsubq_u8(src, v_step2_); | |
120 | 2400 | dst = vqtbx3q_u8(dst, t5_3_, src); | |
121 | 4800 | return dst; | |
122 | 2400 | } | |
123 | |||
124 | 2816 | ScalarType scalar_path(ScalarType src) { return table_pointer_[src]; } | |
125 | |||
126 | private: | ||
127 | const ScalarType *table_pointer_; | ||
128 | 149 | Vector3Type t0_3_{}, t1_3_{}, t3_3_{}, t5_3_{}; | |
129 | 149 | Vector2Type t2_2_{}, t4_2_{}; | |
130 | VectorType v_step3_, v_step2_; | ||
131 | }; // end of class ScaleUint8Tbx | ||
132 | |||
133 | // In contrast to ScaleUint8Tbx, ScaleUint8Calc takes the direct approach: | ||
134 | // - calculate dst[i] = src[i] * scale + shift using vector instructions | ||
135 | class ScaleUint8Calc final : public ScaleIntBase<uint8_t> { | ||
136 | public: | ||
137 | using ScalarType = uint8_t; | ||
138 | using VecTraits = neon::VecTraits<ScalarType>; | ||
139 | using VectorType = typename VecTraits::VectorType; | ||
140 | |||
141 | 154 | ScaleUint8Calc(float scale, float shift) | |
142 | 154 | : ScaleIntBase<ScalarType>(scale, shift), | |
143 | 154 | vscale_{vdupq_n_f32(scale)}, | |
144 | 154 | vshift_{vdupq_n_f32(shift)} {} | |
145 | |||
146 | 1294 | VectorType vector_path(VectorType src) { | |
147 | // For scaling, uint8 values have to be converted to uint32 | ||
148 | // i.e. create four vectors from one | ||
149 | 1294 | uint32x4_t res11 = scale_shift(vqtbl1q_u8(src, w0)); | |
150 | 1294 | uint32x4_t res12 = scale_shift(vqtbl1q_u8(src, w1)); | |
151 | 1294 | uint32x4_t res21 = scale_shift(vqtbl1q_u8(src, w2)); | |
152 | 1294 | uint32x4_t res22 = scale_shift(vqtbl1q_u8(src, w3)); | |
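153 | // Convert back from 32-bit by keeping the low 16 bits of each lane (unzip). NOTE(review): the top two bytes are only 0 when src * scale + shift < 65536; scale_shift does not clamp, so larger results are truncated modulo 65536 before the saturating narrow below (e.g. scale = 512, src = 128 yields 0 here, while scale_value returns 255) | ||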
154 | 2588 | uint16x8_t res1 = | |
155 | 1294 | vuzp1q_u16(vreinterpretq_u16_u32(res11), vreinterpretq_u16_u32(res12)); | |
156 | 2588 | uint16x8_t res2 = | |
157 | 1294 | vuzp1q_u16(vreinterpretq_u16_u32(res21), vreinterpretq_u16_u32(res22)); | |
158 | |||
159 | // Saturating narrowing from 16 to 8 bits | ||
160 | 2588 | return vqmovn_high_u16(vqmovn_u16(res1), res2); | |
161 | 1294 | } | |
162 | |||
163 | 2713 | ScalarType scalar_path(ScalarType src) { | |
164 | 2713 | return scale_value(src, scale_, shift_); | |
165 | } | ||
166 | |||
167 | private: | ||
168 | static constexpr ScalarType FF = std::numeric_limits<uint8_t>::max(); | ||
169 | // clang-format off | ||
170 | static constexpr uint8x16_t w0 = { 0, FF, FF, FF, 1, FF, FF, FF, 2, FF, FF, FF, 3, FF, FF, FF}; | ||
171 | static constexpr uint8x16_t w1 = { 4, FF, FF, FF, 5, FF, FF, FF, 6, FF, FF, FF, 7, FF, FF, FF}; | ||
172 | static constexpr uint8x16_t w2 = { 8, FF, FF, FF, 9, FF, FF, FF, 10, FF, FF, FF, 11, FF, FF, FF}; | ||
173 | static constexpr uint8x16_t w3 = {12, FF, FF, FF, 13, FF, FF, FF, 14, FF, FF, FF, 15, FF, FF, FF}; | ||
174 | // clang-format on | ||
175 | |||
176 | // Convert from uint32 to float32, scale and convert back with rounding | ||
177 | 5176 | inline uint32x4_t scale_shift(VectorType src) { | |
178 | 5176 | float32x4_t fx = vcvtq_f32_u32(vreinterpretq_u32_u8(src)); | |
179 | // scale + shift is done by MLA | ||
180 | 10352 | return vcvtnq_u32_f32(vmlaq_f32(vshift_, fx, vscale_)); | |
181 | 5176 | } | |
182 | |||
183 | float32x4_t vscale_, vshift_; | ||
184 | }; // end of class ScaleUint8Calc | ||
185 | |||
186 | 149 | kleidicv_error_t scale_with_precalculated_table( | |
187 | const uint8_t *src, size_t src_stride, uint8_t *dst, size_t dst_stride, | ||
188 | size_t width, size_t height, float scale, float shift, | ||
189 | const std::array<uint8_t, 256> &precalculated_table) { | ||
190 | 149 | Rectangle rect{width, height}; | |
191 | 149 | Rows<const uint8_t> src_rows{src, src_stride}; | |
192 | 149 | Rows<uint8_t> dst_rows{dst, dst_stride}; | |
193 | 149 | ScaleUint8Tbx operation(scale, shift, precalculated_table.data()); | |
194 | 149 | apply_operation_by_rows(operation, rect, src_rows, dst_rows); | |
195 | |||
196 | 149 | return KLEIDICV_OK; | |
197 | 149 | } | |
198 | |||
199 | // Specialization for uint8_t | ||
200 | template <> | ||
201 | 198 | kleidicv_error_t scale(const uint8_t *src, size_t src_stride, uint8_t *dst, | |
202 | size_t dst_stride, size_t width, size_t height, | ||
203 | float scale, float shift) { | ||
204 |
4/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 195 times.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 195 times.
|
198 | CHECK_POINTER_AND_STRIDE(src, src_stride, height); |
205 |
4/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 192 times.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 192 times.
|
195 | CHECK_POINTER_AND_STRIDE(dst, dst_stride, height); |
206 |
6/6✓ Branch 0 taken 3 times.
✓ Branch 1 taken 189 times.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 186 times.
✓ Branch 4 taken 6 times.
✓ Branch 5 taken 186 times.
|
192 | CHECK_IMAGE_SIZE(width, height); |
207 | // For smaller inputs, the direct calculation is faster | ||
208 |
2/2✓ Branch 0 taken 154 times.
✓ Branch 1 taken 32 times.
|
186 | if (width * height < 675) { // empirical value |
209 | 154 | Rectangle rect{width, height}; | |
210 | 154 | Rows<const uint8_t> src_rows{src, src_stride}; | |
211 | 154 | Rows<uint8_t> dst_rows{dst, dst_stride}; | |
212 | 154 | ScaleUint8Calc operation(scale, shift); | |
213 | 154 | apply_operation_by_rows(operation, rect, src_rows, dst_rows); | |
214 | 154 | } else { | |
215 | // For bigger inputs, it's faster to pre-calculate the table | ||
216 | // and map those values during the run | ||
217 | 32 | auto precalculated_table = precalculate_scale_table_u8(scale, shift); | |
218 | 64 | return scale_with_precalculated_table(src, src_stride, dst, dst_stride, | |
219 | 32 | width, height, scale, shift, | |
220 | precalculated_table); | ||
221 | 32 | } | |
222 | 154 | return KLEIDICV_OK; | |
223 | 198 | } | |
224 | |||
225 | 7424 | static uint32x4_t scale_shift(uint32x4_t src, float scale, float shift) { | |
226 | 7424 | float32x4_t fx = vcvtq_f32_u32(src); | |
227 | 7424 | float32x4_t max = vdupq_n_f32(255.0F); | |
228 | 7424 | float32x4_t min = vdupq_n_f32(0.0F); | |
229 | 7424 | float32x4_t val = vmlaq_f32(vdupq_n_f32(shift), fx, vdupq_n_f32(scale)); | |
230 | 14848 | return vcvtnq_u32_f32(vmaxq_f32(min, vminq_f32(val, max))); | |
231 | 7424 | } | |
232 | |||
233 | 116 | std::array<uint8_t, 256> precalculate_scale_table_u8(float scale, float shift) { | |
234 | static constexpr size_t TableLength = 256; | ||
235 | 116 | std::array<uint8_t, TableLength> precalculated_table{}; | |
236 | |||
237 | 116 | uint32x4_t counter = {0, 1, 2, 3}; | |
238 | 116 | uint32x4_t four = vdupq_n_u32(4); | |
239 | |||
240 |
2/2✓ Branch 0 taken 116 times.
✓ Branch 1 taken 1856 times.
|
1972 | for (size_t i = 0; i < TableLength; i += 16) { |
241 | 1856 | uint32x4_t res11 = scale_shift(counter, scale, shift); | |
242 | 1856 | counter = vaddq(counter, four); | |
243 | 1856 | uint32x4_t res12 = scale_shift(counter, scale, shift); | |
244 | 1856 | counter = vaddq(counter, four); | |
245 | 1856 | uint32x4_t res21 = scale_shift(counter, scale, shift); | |
246 | 1856 | counter = vaddq(counter, four); | |
247 | 1856 | uint32x4_t res22 = scale_shift(counter, scale, shift); | |
248 | 1856 | counter = vaddq(counter, four); | |
249 | |||
250 | 3712 | uint16x8_t res1 = | |
251 | 1856 | vuzp1q_u16(vreinterpretq_u16_u32(res11), vreinterpretq_u16_u32(res12)); | |
252 | 3712 | uint16x8_t res2 = | |
253 | 1856 | vuzp1q_u16(vreinterpretq_u16_u32(res21), vreinterpretq_u16_u32(res22)); | |
254 | // Saturating narrowing from 16 to 8 bits | ||
255 | 1856 | uint8x16_t res = vqmovn_high_u16(vqmovn_u16(res1), res2); | |
256 | |||
257 | 1856 | vst1q_u8(&precalculated_table[i], res); | |
258 | 1856 | } | |
259 | return precalculated_table; | ||
260 | 116 | } | |
261 | |||
262 | // ----------------------------------------------------------------------- | ||
263 | // Float implementation | ||
264 | // ----------------------------------------------------------------------- | ||
265 | |||
266 | class AddFloat final : public UnrollTwice, public UnrollOnce { | ||
267 | public: | ||
268 | using ScalarType = float; | ||
269 | using VecTraits = neon::VecTraits<ScalarType>; | ||
270 | using VectorType = typename VecTraits::VectorType; | ||
271 | |||
272 | 6 | explicit AddFloat(float shift) : shift_{shift}, vshift_{vdupq_n_f32(shift)} {} | |
273 | |||
274 | 5034 | VectorType vector_path(VectorType src) { return vaddq_f32(vshift_, src); } | |
275 | |||
276 | // NOLINTBEGIN(readability-make-member-function-const) | ||
277 | 12 | ScalarType scalar_path(ScalarType src) { return src + shift_; } | |
278 | // NOLINTEND(readability-make-member-function-const) | ||
279 | |||
280 | private: | ||
281 | float shift_; | ||
282 | float32x4_t vshift_; | ||
283 | }; // end of class AddFloat | ||
284 | |||
285 | class ScaleFloat final : public UnrollTwice, public UnrollOnce { | ||
286 | public: | ||
287 | using ScalarType = float; | ||
288 | using VecTraits = neon::VecTraits<ScalarType>; | ||
289 | using VectorType = typename VecTraits::VectorType; | ||
290 | |||
291 | 101 | ScaleFloat(float scale, float shift) | |
292 | 101 | : scale_{scale}, | |
293 | 101 | shift_{shift}, | |
294 | 101 | vscale_{vdupq_n_f32(scale)}, | |
295 | 101 | vshift_{vdupq_n_f32(shift)} {} | |
296 | |||
297 | 45897 | VectorType vector_path(VectorType src) { | |
298 | 45897 | return vmlaq_f32(vshift_, src, vscale_); | |
299 | } | ||
300 | |||
301 | // NOLINTBEGIN(readability-make-member-function-const) | ||
302 | 4083 | ScalarType scalar_path(ScalarType src) { return src * scale_ + shift_; } | |
303 | // NOLINTEND(readability-make-member-function-const) | ||
304 | |||
305 | private: | ||
306 | float scale_, shift_; | ||
307 | float32x4_t vscale_, vshift_; | ||
308 | }; // end of class ScaleFloat | ||
309 | |||
310 | // Specialization for float | ||
311 | template <> | ||
312 | 113 | kleidicv_error_t scale(const float *src, size_t src_stride, float *dst, | |
313 | size_t dst_stride, size_t width, size_t height, | ||
314 | float scale, float shift) { | ||
315 |
4/4✓ Branch 0 taken 2 times.
✓ Branch 1 taken 111 times.
✓ Branch 2 taken 2 times.
✓ Branch 3 taken 111 times.
|
113 | CHECK_POINTER_AND_STRIDE(src, src_stride, height); |
316 |
4/4✓ Branch 0 taken 2 times.
✓ Branch 1 taken 109 times.
✓ Branch 2 taken 2 times.
✓ Branch 3 taken 109 times.
|
111 | CHECK_POINTER_AND_STRIDE(dst, dst_stride, height); |
317 |
6/6✓ Branch 0 taken 1 times.
✓ Branch 1 taken 108 times.
✓ Branch 2 taken 1 times.
✓ Branch 3 taken 107 times.
✓ Branch 4 taken 2 times.
✓ Branch 5 taken 107 times.
|
109 | CHECK_IMAGE_SIZE(width, height); |
318 | |||
319 | 107 | Rectangle rect{width, height}; | |
320 | 107 | Rows<const float> src_rows{src, src_stride}; | |
321 | 107 | Rows<float> dst_rows{dst, dst_stride}; | |
322 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 101 times.
|
107 | if (scale == 1.0) { |
323 | 6 | AddFloat operation(shift); | |
324 | 6 | apply_operation_by_rows(operation, rect, src_rows, dst_rows); | |
325 | 6 | } else { | |
326 | 101 | ScaleFloat operation(scale, shift); | |
327 | 101 | apply_operation_by_rows(operation, rect, src_rows, dst_rows); | |
328 | 101 | } | |
329 | 107 | return KLEIDICV_OK; | |
330 | 113 | } | |
331 | |||
332 | } // namespace kleidicv::neon | ||
333 |