@@ -72,6 +72,13 @@ bool is_supported(cpu_isa_t isa, alg_kind_t alg) {
7272
7373using namespace Xbyak_aarch64 ;
7474
75+ template <cpu_isa_t isa>
76+ void jit_uni_eltwise_injector_t <isa>::set_input_range(
77+ float min_value, float max_value) {
78+ min_input_ = min_value;
79+ max_input_ = max_value;
80+ }
81+
7582template <cpu_isa_t isa>
7683void jit_uni_eltwise_injector_t <isa>::injector_preamble(
7784 const injector_utils::vmm_index_set_t &vmm_idxs) {
@@ -120,7 +127,7 @@ void jit_uni_eltwise_injector_t<isa>::injector_preamble(
120127 h->sub_imm (h->X_SP , h->X_SP , preserved_vecs_count * vlen,
121128 h->X_TMP_0 );
122129 for (size_t i = 0 ; i < preserved_vecs_count; ++i)
123- h-> str ( ZReg (preserved_vec_idxs[i]), ptr (h-> X_SP , i, MUL_VL) );
130+ store_preserved_vec (i, preserved_vec_idxs[i] );
124131 }
125132 load_table_addr ();
126133 }
@@ -141,17 +148,15 @@ void jit_uni_eltwise_injector_t<isa>::injector_preamble_tail(
141148 if (idx_off) h->add_imm (h->X_SP , h->X_SP , idx_off * vlen, h->X_TMP_0 );
142149
143150 for (size_t i = 0 ; i < tail_vecs_to_preserve; ++i)
144- h->ldr (ZReg (preserved_vec_idxs[idx_off + i]),
145- ptr (h->X_SP , i, MUL_VL));
151+ load_preserved_vec (i, preserved_vec_idxs[idx_off + i]);
146152 }
147153
148154 for (size_t i = 0 ; i < tail_vecs_to_preserve; ++i)
149155 preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
150156
151157 if (save_state_ && preserve_vmm_) {
152158 for (size_t i = 0 ; i < tail_vecs_to_preserve; ++i)
153- h->str (ZReg (preserved_vec_idxs[idx_off + i]),
154- ptr (h->X_SP , i, MUL_VL));
159+ store_preserved_vec (i, preserved_vec_idxs[idx_off + i]);
155160
156161 if (idx_off) h->sub_imm (h->X_SP , h->X_SP , idx_off * vlen, h->X_TMP_0 );
157162 }
@@ -168,7 +173,7 @@ void jit_uni_eltwise_injector_t<isa>::injector_postamble() {
168173
169174 if (preserve_vmm_) {
170175 for (size_t i = 0 ; i < preserved_vecs_count; ++i)
171- h-> ldr ( ZReg (preserved_vec_idxs[i]), ptr (h-> X_SP , i, MUL_VL) );
176+ load_preserved_vec (i, preserved_vec_idxs[i] );
172177
173178 if (preserved_vecs_count)
174179 h->add_imm (
@@ -196,6 +201,26 @@ void jit_uni_eltwise_injector_t<isa>::assign_regs() {
196201 vmm_aux7 = TRegS (preserved_vec_idxs[8 ]);
197202}
198203
204+ template <cpu_isa_t isa>
205+ inline void jit_uni_eltwise_injector_t <isa>::store_preserved_vec(
206+ size_t slot, size_t vmm_idx) {
207+ if (isa == asimd) {
208+ h->str (QReg (vmm_idx), ptr (h->X_SP , static_cast <int32_t >(slot * vlen)));
209+ } else {
210+ h->str (ZReg (vmm_idx), ptr (h->X_SP , slot, MUL_VL));
211+ }
212+ }
213+
214+ template <cpu_isa_t isa>
215+ inline void jit_uni_eltwise_injector_t <isa>::load_preserved_vec(
216+ size_t slot, size_t vmm_idx) {
217+ if (isa == asimd) {
218+ h->ldr (QReg (vmm_idx), ptr (h->X_SP , static_cast <int32_t >(slot * vlen)));
219+ } else {
220+ h->ldr (ZReg (vmm_idx), ptr (h->X_SP , slot, MUL_VL));
221+ }
222+ }
223+
199224template <cpu_isa_t isa>
200225void jit_uni_eltwise_injector_t <isa>::set_coef_to_regs() {
201226 using namespace alg_kind ;
@@ -434,8 +459,8 @@ void jit_uni_eltwise_injector_t<isa>::exp_compute_vector_fwd(
434459 const TRegS &vmm_src) {
435460
436461 const auto &t0 = ZRegS (IDX (vmm_src));
437- const auto &t1 = ZRegS (IDX (vmm_aux1 ));
438- const auto &t2 = ZRegS (IDX (vmm_aux2 ));
462+ const auto &t1 = ZRegS (IDX (vmm_aux0 ));
463+ const auto &t2 = ZRegS (IDX (vmm_aux1 ));
439464 h->fmin (t0, p_all, ZRegS (IDX (table_val (exp_ln_flt_max_f, z_tmp))));
440465 h->fmax (t0, p_all, ZRegS (IDX (table_val (exp_ln_flt_min_f, z_tmp))));
441466 h->fmul (t0, t0, ZRegS (IDX (table_val (exp_log2ef, z_tmp))));
@@ -1591,7 +1616,7 @@ size_t jit_uni_eltwise_injector_t<isa>::aux_vecs_count() {
15911616 case eltwise_logistic_use_dst_for_bwd:
15921617 case eltwise_logistic: return 5 ; /* = exp + 1 */
15931618 case eltwise_exp_use_dst_for_bwd:
1594- case eltwise_exp: return (isa == asimd) ? 5 : 4 ;
1619+ case eltwise_exp: return (isa == asimd) ? 5 : 3 ;
15951620 case eltwise_gelu_tanh: return 9 ; /* = tanh */
15961621 case eltwise_swish: return 6 ; /* = logistic */
15971622 case eltwise_log: return 6 ;
@@ -2506,13 +2531,24 @@ void jit_uni_eltwise_injector_t<asimd>::exp_compute_vector_fwd(
25062531 * For very large |n| > 196, use exp(x) = s1*s1.
25072532 */
25082533
2534+ Xbyak_aarch64::Label L_done, L_special;
25092535 const auto &t0 = VReg4S (vmm_src.getIdx ());
25102536 const auto &t1 = VReg4S (vmm_aux0.getIdx ());
25112537 const auto &t2 = VReg4S (vmm_aux1.getIdx ());
25122538 const auto &t3 = VReg4S (vmm_aux2.getIdx ());
25132539 const auto &t4 = VReg4S (vmm_aux3.getIdx ());
25142540 const auto &t_tmp = VReg4S (vmm_tmp.getIdx ());
25152541
2542+ const float special_bound_input = 126 .5f * logf (2 .0f );
2543+ const float ln_flt_min = logf (FLT_MIN);
2544+ bool need_clamp = min_input_ < ln_flt_min;
2545+ bool need_special_case = max_input_ >= special_bound_input;
2546+
2547+ if (!need_special_case && need_clamp) {
2548+ // Clamp x to avoid overflow of f32 exponent bits
2549+ h->fmax (t0, t0, table_val (exp_ln_flt_min_f, t4));
2550+ }
2551+
25162552 // z = x * inv_ln2
25172553 const auto &v_inv_ln2 = table_val (exp_log2ef, t_tmp);
25182554 const auto &v_src = t0;
@@ -2565,15 +2601,17 @@ void jit_uni_eltwise_injector_t<asimd>::exp_compute_vector_fwd(
25652601 h->fmul (v_p, v_c4, v_r);
25662602 h->fmla (v_p, v_q, v_r2);
25672603
2568- // Check if any lane needs special-case handling
2569- // mask_special = (|n| > 126)
2570- Xbyak_aarch64::Label L_done, L_special;
2571- const auto &v_mask_special = v_c3;
2572- const auto &v_special_bound = table_val (exp_special_bound, t_tmp); // 126.0f
2573- h->facgt (v_mask_special, v_n, v_special_bound);
2574- h->addp (DReg (t_tmp.getIdx ()), VReg2D (v_mask_special.getIdx ()));
2575- h->fmov (h->X_TMP_0 , DReg (t_tmp.getIdx ()));
2576- h->cbnz (h->X_TMP_0 , L_special);
2604+ if (need_special_case) {
2605+ // Check if any lane needs special-case handling
2606+ // mask_special = (|n| > 126)
2607+ const auto &v_mask_special = v_c3;
2608+ const auto &v_special_bound
2609+ = table_val (exp_special_bound, t_tmp); // 126.0f
2610+ h->facgt (v_mask_special, v_n, v_special_bound);
2611+ h->addp (DReg (t_tmp.getIdx ()), VReg2D (v_mask_special.getIdx ()));
2612+ h->fmov (h->X_TMP_0 , DReg (t_tmp.getIdx ()));
2613+ h->cbnz (h->X_TMP_0 , L_special);
2614+ }
25772615
25782616 // ===== Fast path =====
25792617 // scale = reinterpret_f32(e + 0x3f800000) = 2^n
@@ -2585,47 +2623,49 @@ void jit_uni_eltwise_injector_t<asimd>::exp_compute_vector_fwd(
25852623 // exp(x) = scale + poly * scale = 2^n * (1 + poly(r))
25862624 const auto &v_dst = v_src;
25872625 h->fmla (v_dst, v_p, v_exp_scale);
2588- h->b (L_done);
2589-
2590- // ===== Special-case handling =====
2591- // b = (n <= 0) ? special_offset : 0
2592- h->L (L_special);
2593- const auto &v_b = v_q;
2594- const auto &v_special_offset
2595- = table_val (exp_special_offset, t_tmp); // 0x82000000
2596- h->fcmle (v_b, v_n, 0.0 );
2597- h->and_ (VReg16B (v_b.getIdx ()), VReg16B (v_b.getIdx ()),
2598- VReg16B (v_special_offset.getIdx ()));
2599-
2600- // mask_thresh = (|n| > 192)
2601- const auto &v_thresh = table_val (exp_scale_thresh, t_tmp); // 192.0f
2602- const auto &v_mask_thresh = v_n;
2603- h->facgt (v_mask_thresh, v_n, v_thresh);
2604-
2605- // s2_bits = e - b
2606- const auto &v_s2 = v_e;
2607- h->sub (v_s2, v_e, v_b);
2608-
2609- // s1_bits = b + special_bias
2610- const auto &v_s1 = v_b;
2611- const auto &v_special_bias
2612- = table_val (exp_special_bias, t_tmp); // 0x7f000000
2613- h->add (v_s1, v_b, v_special_bias);
2614-
2615- // r0 = (s2 + poly*s2) * s1
2616- const auto &v_r0 = v_p;
2617- h->fmla (v_s2, v_p, v_s2);
2618- h->fmul (v_r0, v_s2, v_s1);
2619-
2620- // r1 = s1 * s1
2621- const auto &v_r1 = v_dst;
2622- h->fmul (v_r1, v_s1, v_s1);
2623-
2624- // out_special = (|n| > 192) ? r1 : r0
2625- h->bif (VReg16B (v_r1.getIdx ()), VReg16B (v_r0.getIdx ()),
2626- VReg16B (v_mask_thresh.getIdx ()));
2627-
2628- h->L (L_done);
2626+
2627+ if (need_special_case) {
2628+ // ===== Special-case handling =====
2629+ // b = (n <= 0) ? special_offset : 0
2630+ h->b (L_done);
2631+ h->L (L_special);
2632+ const auto &v_b = v_q;
2633+ const auto &v_special_offset
2634+ = table_val (exp_special_offset, t_tmp); // 0x82000000
2635+ h->fcmle (v_b, v_n, 0.0 );
2636+ h->and_ (VReg16B (v_b.getIdx ()), VReg16B (v_b.getIdx ()),
2637+ VReg16B (v_special_offset.getIdx ()));
2638+
2639+ // mask_thresh = (|n| > 192)
2640+ const auto &v_thresh = table_val (exp_scale_thresh, t_tmp); // 192.0f
2641+ const auto &v_mask_thresh = v_n;
2642+ h->facgt (v_mask_thresh, v_n, v_thresh);
2643+
2644+ // s2_bits = e - b
2645+ const auto &v_s2 = v_e;
2646+ h->sub (v_s2, v_e, v_b);
2647+
2648+ // s1_bits = b + special_bias
2649+ const auto &v_s1 = v_b;
2650+ const auto &v_special_bias
2651+ = table_val (exp_special_bias, t_tmp); // 0x7f000000
2652+ h->add (v_s1, v_b, v_special_bias);
2653+
2654+ // r0 = (s2 + poly*s2) * s1
2655+ const auto &v_r0 = v_p;
2656+ h->fmla (v_s2, v_p, v_s2);
2657+ h->fmul (v_r0, v_s2, v_s1);
2658+
2659+ // r1 = s1 * s1
2660+ const auto &v_r1 = v_dst;
2661+ h->fmul (v_r1, v_s1, v_s1);
2662+
2663+ // out_special = (|n| > 192) ? r1 : r0
2664+ h->bif (VReg16B (v_r1.getIdx ()), VReg16B (v_r0.getIdx ()),
2665+ VReg16B (v_mask_thresh.getIdx ()));
2666+
2667+ h->L (L_done);
2668+ }
26292669}
26302670
26312671template <>
0 commit comments