Skip to content

Commit 9555ee7

Browse files
committed
cpu: aarch64: add ASIMD softmax JIT implementation
This commit introduces an f32 ASIMD `softmax` JIT implementation using the `exp` eltwise injector added in #4376, while also improving performance for the existing `sve_*` implementations (primarily by increasing the unrolling factor `unroll_regs_` and skipping the multiplication with default dequantization/requantization factors `src_scales` /`dst_scales`). For `jit:asimd` and `jit:sve_128`, the `exp` function is also effectively inlined by setting `preserve_vmm = false`, whereas `jit:sve_256` did not benefit from such a change. As the previous softmax implementation heavily relied on predicated instructions, `jit_softmax_base_t` was refactored to only include common logic for SVE and non-SVE implementations alike. At the same time, two different derived constructs were added to handle ISA-specific work: `jit_softmax_sve_t` and `jit_softmax_asimd_t`. In addition, the JIT eltwise injector was changed to support storing/loading preserved vectors on non-SVE targets.
1 parent 8b58b7e commit 9555ee7

File tree

5 files changed

+824
-510
lines changed

5 files changed

+824
-510
lines changed

src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp

Lines changed: 99 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ bool is_supported(cpu_isa_t isa, alg_kind_t alg) {
7272

7373
using namespace Xbyak_aarch64;
7474

75+
template <cpu_isa_t isa>
76+
void jit_uni_eltwise_injector_t<isa>::set_input_range(
77+
float min_value, float max_value) {
78+
min_input_ = min_value;
79+
max_input_ = max_value;
80+
}
81+
7582
template <cpu_isa_t isa>
7683
void jit_uni_eltwise_injector_t<isa>::injector_preamble(
7784
const injector_utils::vmm_index_set_t &vmm_idxs) {
@@ -120,7 +127,7 @@ void jit_uni_eltwise_injector_t<isa>::injector_preamble(
120127
h->sub_imm(h->X_SP, h->X_SP, preserved_vecs_count * vlen,
121128
h->X_TMP_0);
122129
for (size_t i = 0; i < preserved_vecs_count; ++i)
123-
h->str(ZReg(preserved_vec_idxs[i]), ptr(h->X_SP, i, MUL_VL));
130+
store_preserved_vec(i, preserved_vec_idxs[i]);
124131
}
125132
load_table_addr();
126133
}
@@ -141,17 +148,15 @@ void jit_uni_eltwise_injector_t<isa>::injector_preamble_tail(
141148
if (idx_off) h->add_imm(h->X_SP, h->X_SP, idx_off * vlen, h->X_TMP_0);
142149

143150
for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
144-
h->ldr(ZReg(preserved_vec_idxs[idx_off + i]),
145-
ptr(h->X_SP, i, MUL_VL));
151+
load_preserved_vec(i, preserved_vec_idxs[idx_off + i]);
146152
}
147153

148154
for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
149155
preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
150156

151157
if (save_state_ && preserve_vmm_) {
152158
for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
153-
h->str(ZReg(preserved_vec_idxs[idx_off + i]),
154-
ptr(h->X_SP, i, MUL_VL));
159+
store_preserved_vec(i, preserved_vec_idxs[idx_off + i]);
155160

156161
if (idx_off) h->sub_imm(h->X_SP, h->X_SP, idx_off * vlen, h->X_TMP_0);
157162
}
@@ -168,7 +173,7 @@ void jit_uni_eltwise_injector_t<isa>::injector_postamble() {
168173

169174
if (preserve_vmm_) {
170175
for (size_t i = 0; i < preserved_vecs_count; ++i)
171-
h->ldr(ZReg(preserved_vec_idxs[i]), ptr(h->X_SP, i, MUL_VL));
176+
load_preserved_vec(i, preserved_vec_idxs[i]);
172177

173178
if (preserved_vecs_count)
174179
h->add_imm(
@@ -196,6 +201,26 @@ void jit_uni_eltwise_injector_t<isa>::assign_regs() {
196201
vmm_aux7 = TRegS(preserved_vec_idxs[8]);
197202
}
198203

204+
template <cpu_isa_t isa>
205+
inline void jit_uni_eltwise_injector_t<isa>::store_preserved_vec(
206+
size_t slot, size_t vmm_idx) {
207+
if (isa == asimd) {
208+
h->str(QReg(vmm_idx), ptr(h->X_SP, static_cast<int32_t>(slot * vlen)));
209+
} else {
210+
h->str(ZReg(vmm_idx), ptr(h->X_SP, slot, MUL_VL));
211+
}
212+
}
213+
214+
template <cpu_isa_t isa>
215+
inline void jit_uni_eltwise_injector_t<isa>::load_preserved_vec(
216+
size_t slot, size_t vmm_idx) {
217+
if (isa == asimd) {
218+
h->ldr(QReg(vmm_idx), ptr(h->X_SP, static_cast<int32_t>(slot * vlen)));
219+
} else {
220+
h->ldr(ZReg(vmm_idx), ptr(h->X_SP, slot, MUL_VL));
221+
}
222+
}
223+
199224
template <cpu_isa_t isa>
200225
void jit_uni_eltwise_injector_t<isa>::set_coef_to_regs() {
201226
using namespace alg_kind;
@@ -434,8 +459,8 @@ void jit_uni_eltwise_injector_t<isa>::exp_compute_vector_fwd(
434459
const TRegS &vmm_src) {
435460

436461
const auto &t0 = ZRegS(IDX(vmm_src));
437-
const auto &t1 = ZRegS(IDX(vmm_aux1));
438-
const auto &t2 = ZRegS(IDX(vmm_aux2));
462+
const auto &t1 = ZRegS(IDX(vmm_aux0));
463+
const auto &t2 = ZRegS(IDX(vmm_aux1));
439464
h->fmin(t0, p_all, ZRegS(IDX(table_val(exp_ln_flt_max_f, z_tmp))));
440465
h->fmax(t0, p_all, ZRegS(IDX(table_val(exp_ln_flt_min_f, z_tmp))));
441466
h->fmul(t0, t0, ZRegS(IDX(table_val(exp_log2ef, z_tmp))));
@@ -1591,7 +1616,7 @@ size_t jit_uni_eltwise_injector_t<isa>::aux_vecs_count() {
15911616
case eltwise_logistic_use_dst_for_bwd:
15921617
case eltwise_logistic: return 5; /* = exp + 1 */
15931618
case eltwise_exp_use_dst_for_bwd:
1594-
case eltwise_exp: return (isa == asimd) ? 5 : 4;
1619+
case eltwise_exp: return (isa == asimd) ? 5 : 3;
15951620
case eltwise_gelu_tanh: return 9; /* = tanh */
15961621
case eltwise_swish: return 6; /* = logistic */
15971622
case eltwise_log: return 6;
@@ -2506,13 +2531,24 @@ void jit_uni_eltwise_injector_t<asimd>::exp_compute_vector_fwd(
25062531
* For very large |n| > 196, use exp(x) = s1*s1.
25072532
*/
25082533

2534+
Xbyak_aarch64::Label L_done, L_special;
25092535
const auto &t0 = VReg4S(vmm_src.getIdx());
25102536
const auto &t1 = VReg4S(vmm_aux0.getIdx());
25112537
const auto &t2 = VReg4S(vmm_aux1.getIdx());
25122538
const auto &t3 = VReg4S(vmm_aux2.getIdx());
25132539
const auto &t4 = VReg4S(vmm_aux3.getIdx());
25142540
const auto &t_tmp = VReg4S(vmm_tmp.getIdx());
25152541

2542+
const float special_bound_input = 126.5f * logf(2.0f);
2543+
const float ln_flt_min = logf(FLT_MIN);
2544+
bool need_clamp = min_input_ < ln_flt_min;
2545+
bool need_special_case = max_input_ >= special_bound_input;
2546+
2547+
if (!need_special_case && need_clamp) {
2548+
// Clamp x to avoid overflow of f32 exponent bits
2549+
h->fmax(t0, t0, table_val(exp_ln_flt_min_f, t4));
2550+
}
2551+
25162552
// z = x * inv_ln2
25172553
const auto &v_inv_ln2 = table_val(exp_log2ef, t_tmp);
25182554
const auto &v_src = t0;
@@ -2565,15 +2601,17 @@ void jit_uni_eltwise_injector_t<asimd>::exp_compute_vector_fwd(
25652601
h->fmul(v_p, v_c4, v_r);
25662602
h->fmla(v_p, v_q, v_r2);
25672603

2568-
// Check if any lane needs special-case handling
2569-
// mask_special = (|n| > 126)
2570-
Xbyak_aarch64::Label L_done, L_special;
2571-
const auto &v_mask_special = v_c3;
2572-
const auto &v_special_bound = table_val(exp_special_bound, t_tmp); // 126.0f
2573-
h->facgt(v_mask_special, v_n, v_special_bound);
2574-
h->addp(DReg(t_tmp.getIdx()), VReg2D(v_mask_special.getIdx()));
2575-
h->fmov(h->X_TMP_0, DReg(t_tmp.getIdx()));
2576-
h->cbnz(h->X_TMP_0, L_special);
2604+
if (need_special_case) {
2605+
// Check if any lane needs special-case handling
2606+
// mask_special = (|n| > 126)
2607+
const auto &v_mask_special = v_c3;
2608+
const auto &v_special_bound
2609+
= table_val(exp_special_bound, t_tmp); // 126.0f
2610+
h->facgt(v_mask_special, v_n, v_special_bound);
2611+
h->addp(DReg(t_tmp.getIdx()), VReg2D(v_mask_special.getIdx()));
2612+
h->fmov(h->X_TMP_0, DReg(t_tmp.getIdx()));
2613+
h->cbnz(h->X_TMP_0, L_special);
2614+
}
25772615

25782616
// ===== Fast path =====
25792617
// scale = reinterpret_f32(e + 0x3f800000) = 2^n
@@ -2585,47 +2623,49 @@ void jit_uni_eltwise_injector_t<asimd>::exp_compute_vector_fwd(
25852623
// exp(x) = scale + poly * scale = 2^n * (1 + poly(r))
25862624
const auto &v_dst = v_src;
25872625
h->fmla(v_dst, v_p, v_exp_scale);
2588-
h->b(L_done);
2589-
2590-
// ===== Special-case handling =====
2591-
// b = (n <= 0) ? special_offset : 0
2592-
h->L(L_special);
2593-
const auto &v_b = v_q;
2594-
const auto &v_special_offset
2595-
= table_val(exp_special_offset, t_tmp); // 0x82000000
2596-
h->fcmle(v_b, v_n, 0.0);
2597-
h->and_(VReg16B(v_b.getIdx()), VReg16B(v_b.getIdx()),
2598-
VReg16B(v_special_offset.getIdx()));
2599-
2600-
// mask_thresh = (|n| > 192)
2601-
const auto &v_thresh = table_val(exp_scale_thresh, t_tmp); // 192.0f
2602-
const auto &v_mask_thresh = v_n;
2603-
h->facgt(v_mask_thresh, v_n, v_thresh);
2604-
2605-
// s2_bits = e - b
2606-
const auto &v_s2 = v_e;
2607-
h->sub(v_s2, v_e, v_b);
2608-
2609-
// s1_bits = b + special_bias
2610-
const auto &v_s1 = v_b;
2611-
const auto &v_special_bias
2612-
= table_val(exp_special_bias, t_tmp); // 0x7f000000
2613-
h->add(v_s1, v_b, v_special_bias);
2614-
2615-
// r0 = (s2 + poly*s2) * s1
2616-
const auto &v_r0 = v_p;
2617-
h->fmla(v_s2, v_p, v_s2);
2618-
h->fmul(v_r0, v_s2, v_s1);
2619-
2620-
// r1 = s1 * s1
2621-
const auto &v_r1 = v_dst;
2622-
h->fmul(v_r1, v_s1, v_s1);
2623-
2624-
// out_special = (|n| > 192) ? r1 : r0
2625-
h->bif(VReg16B(v_r1.getIdx()), VReg16B(v_r0.getIdx()),
2626-
VReg16B(v_mask_thresh.getIdx()));
2627-
2628-
h->L(L_done);
2626+
2627+
if (need_special_case) {
2628+
// ===== Special-case handling =====
2629+
// b = (n <= 0) ? special_offset : 0
2630+
h->b(L_done);
2631+
h->L(L_special);
2632+
const auto &v_b = v_q;
2633+
const auto &v_special_offset
2634+
= table_val(exp_special_offset, t_tmp); // 0x82000000
2635+
h->fcmle(v_b, v_n, 0.0);
2636+
h->and_(VReg16B(v_b.getIdx()), VReg16B(v_b.getIdx()),
2637+
VReg16B(v_special_offset.getIdx()));
2638+
2639+
// mask_thresh = (|n| > 192)
2640+
const auto &v_thresh = table_val(exp_scale_thresh, t_tmp); // 192.0f
2641+
const auto &v_mask_thresh = v_n;
2642+
h->facgt(v_mask_thresh, v_n, v_thresh);
2643+
2644+
// s2_bits = e - b
2645+
const auto &v_s2 = v_e;
2646+
h->sub(v_s2, v_e, v_b);
2647+
2648+
// s1_bits = b + special_bias
2649+
const auto &v_s1 = v_b;
2650+
const auto &v_special_bias
2651+
= table_val(exp_special_bias, t_tmp); // 0x7f000000
2652+
h->add(v_s1, v_b, v_special_bias);
2653+
2654+
// r0 = (s2 + poly*s2) * s1
2655+
const auto &v_r0 = v_p;
2656+
h->fmla(v_s2, v_p, v_s2);
2657+
h->fmul(v_r0, v_s2, v_s1);
2658+
2659+
// r1 = s1 * s1
2660+
const auto &v_r1 = v_dst;
2661+
h->fmul(v_r1, v_s1, v_s1);
2662+
2663+
// out_special = (|n| > 192) ? r1 : r0
2664+
h->bif(VReg16B(v_r1.getIdx()), VReg16B(v_r0.getIdx()),
2665+
VReg16B(v_mask_thresh.getIdx()));
2666+
2667+
h->L(L_done);
2668+
}
26292669
}
26302670

26312671
template <>

src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,15 @@ struct jit_uni_eltwise_injector_t {
144144
void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); }
145145
void prepare_table(bool gen_table = true);
146146
void load_table_addr() { h->adr(x_table, l_table); }
147+
void set_input_range(float min_value, float max_value);
147148

148149
private:
149150
const alg_kind_t alg_;
150151
const float alpha_;
151152
const float beta_;
152153
const float scale_;
154+
float max_input_ = INFINITY;
155+
float min_input_ = -INFINITY;
153156

154157
jit_generator_t *const h;
155158

@@ -212,6 +215,8 @@ struct jit_uni_eltwise_injector_t {
212215
const injector_utils::vmm_index_set_iterator_t start_idx_it);
213216
void injector_postamble();
214217
void assign_regs();
218+
void store_preserved_vec(size_t slot, size_t vmm_idx);
219+
void load_preserved_vec(size_t slot, size_t vmm_idx);
215220
void set_coef_to_regs();
216221
void compute_cmp_mask(
217222
const TRegS &vmm_src, const TRegS &vmm_cmpare, int cmp_predicate);

0 commit comments

Comments
 (0)