| // Copyright 2022 Google LLC |
| // |
| // This source code is licensed under the BSD-style license found in the |
| // LICENSE file in the root directory of this source tree. |
| |
| $assert REQUANTIZATION in ["FP32", "RNDNU"] |
| $assert not CHANNELWISE or REQUANTIZATION == "FP32" |
| $assert DATATYPE in ["QC8", "QS8"] |
| $assert DATATYPE != "QC8" or REQUANTIZATION == "FP32" |
| |
| #include <xnnpack/assembly.h> |
| |
| .syntax unified |
| |
| $PARAMS_UNION = "xnn_qs8_minmax_params" if CHANNELWISE else "xnn_qs8_conv_minmax_params" |
| # LINT.IfChange |
| // void xnn_${DATATYPE.lower()}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_4x8c4__aarch32_neondot_ld64( |
| // size_t mr, r0 |
| // size_t nc, r1 |
| // size_t kc, r2 -> r5 -> sp + 52 |
| // size_t ks, r3 -> sp + 56 -> r14 |
| // const int8_t**restrict a, sp + 96 -> r2 |
| // const void*restrict w, sp + 100 -> r9 |
| // int8_t*restrict c, sp + 104 -> r11 |
| // size_t cm_stride, sp + 108 -> (r6) |
| // size_t cn_stride, sp + 112 -> (r7) |
| // size_t a_offset, sp + 116 -> (r5) |
| // const int8_t* zero, sp + 120 -> (r7) |
| // ${PARAMS_UNION}*params); sp + 124 -> (r5) |
| |
| // d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved. |
| |
| // Register usage |
| // A0 r3 d0 |
| // A1 r12 d1 |
| // A2 r10 d2 |
| // A3 r0 d3 |
| |
| // B r9 q2 q3 q4 q5 |
| |
| // C0 r11 d16-d17 q8 d18-d19 q9 |
| // C1 r4 d20-d21 q10 d22-d23 q11 |
| // C2 r8 d24-d25 q12 d26-d27 q13 |
| // C3 r6 d28-d29 q14 d30-d31 q15 |
| |
| // unused q7 |
| |
| $if REQUANTIZATION == "RNDNU": |
| // params structure is 16 bytes |
| // struct { |
| // int32_t right_pre_shift; d12[0] |
| // int32_t multiplier; d12[1] |
| // int32_t right_post_shift; d13[0] |
| // int16_t output_zero_point; d13[2] |
| // int8_t output_min; d13[6] |
| // int8_t output_max; d13[7] |
| // } rndnu_neon; |
| $else: |
| // params structure is 4 bytes |
| // struct { |
| // int16_t output_zero_point; d13[2] |
| // int8_t output_min; d13[6] |
| // int8_t output_max; d13[7] |
| // } xnn_qs8_minmax_params.neonv8; |
| |
| // iOS does not support 32 bit ARM with Neon DotProduct. |
| #ifndef __APPLE__ |
| BEGIN_FUNCTION xnn_${DATATYPE.lower()}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_4x8c4__aarch32_neondot_ld64 |
| ADD r2, r2, 3 // kc = (kc + 3) & ~3 |
| BIC r2, r2, 3 |
| # Push 96 bytes |
| # r2 will be reloaded in outer loop. r3 is ks |
| PUSH {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, lr} // +44 |
| SUB sp, sp, 4 // 4 |
| VPUSH {d8-d13} // +48 = 96 |
| |
| LDR r11, [sp, 104] // c |
| LDR r6, [sp, 108] // cm_stride |
| LDR r2, [sp, 96] // a |
| LDR r9, [sp, 100] // w |
| LDR r5, [sp, 124] // params |
| MOV r14, r3 // p = ks |
| |
| # Clamp C pointers |
| CMP r0, 2 // if mr >= 2 |
| ADD r4, r11, r6 // c1 = c0 + cm_stride |
| MOVLO r4, r11 // c1 |
| // if mr > 2 |
| ADD r8, r4, r6 // c2 = c1 + cm_stride |
| MOVLS r8, r4 // c2 |
| CMP r0, 4 // if mr >=4 |
| ADD r6, r8, r6 // c3 = c2 + cm_stride |
| MOVLO r6, r8 // c3 |
| |
| # Load params values |
| $if REQUANTIZATION == "RNDNU": |
| VLDM r5, {d12-d13} // RNDNU params |
| $else: |
| VLD1.32 {d13[]}, [r5] // QC8 params |
| |
| 0: |
| # Load initial bias from w into accumulators |
| VLDM r9!, {d16-d19} // Bias |
| VMOV q10, q8 |
| VMOV q11, q9 |
| LDR r7, [sp, 120] // zero |
| VMOV q12, q8 |
| VMOV q13, q9 |
| VMOV q14, q8 |
| VMOV q15, q9 |
| |
| 1: |
| # Load next 4 A pointers |
| LDR r3, [r2, 0] |
| LDR r12, [r2, 4] |
| LDR r10, [r2, 8] |
| LDR r0, [r2, 12] |
| ADD r2, r2, 16 |
| |
| # Add a_offset |
| LDR r5, [sp, 116] // a_offset |
| CMP r3, r7 // if a0 == zero |
| ADD r3, r3, r5 // a0 += a_offset |
| MOVEQ r3, r7 // a0 = zero, else += a0 + a_offset |
| CMP r12, r7 // if a1 == zero |
| ADD r12, r12, r5 // a1 += a_offset |
| MOVEQ r12, r7 // a1 = zero, else += a1 + a_offset |
| CMP r10, r7 // if a2 == zero |
| ADD r10, r10, r5 // a2 += a_offset |
| MOVEQ r10, r7 // a2 = zero, else += a2 + a_offset |
| CMP r0, r7 // if a3 == zero |
| ADD r0, r0, r5 // a3 += a_offset |
| LDR r5, [sp, 52] // kc |
| MOVEQ r0, r7 // a3 = zero, else += a3 + a_offset |
| |
| SUBS r5, r5, 8 // kc - 8 |
| BLO 4f // less than 8 channels? |
| |
| # Main loop - 8 bytes of A. |
| # 16 SDOT, 4 LD64 A, 4 LD128 B |
| .p2align 3 |
| 2: |
| VLD1.8 {d0}, [r3]! // A0 |
| VLD1.8 {q2}, [r9]! // B0 |
| VLD1.8 {d1}, [r12]! // A1 |
| VLD1.8 {q3}, [r9]! // B1 |
| VLD1.8 {d2}, [r10]! // A2 |
| VLD1.8 {q4}, [r9]! // B2 |
| VLD1.8 {d3}, [r0]! // A3 |
| VLD1.8 {q5}, [r9]! // B3 |
| SUBS r5, r5, 8 |
| |
| VSDOT.S8 q8, q2, d0[0] |
| VSDOT.S8 q9, q3, d0[0] |
| VSDOT.S8 q10, q2, d1[0] |
| VSDOT.S8 q11, q3, d1[0] |
| VSDOT.S8 q12, q2, d2[0] |
| VSDOT.S8 q13, q3, d2[0] |
| VSDOT.S8 q14, q2, d3[0] |
| VSDOT.S8 q15, q3, d3[0] |
| |
| VSDOT.S8 q8, q4, d0[1] |
| VSDOT.S8 q9, q5, d0[1] |
| VSDOT.S8 q10, q4, d1[1] |
| VSDOT.S8 q11, q5, d1[1] |
| VSDOT.S8 q12, q4, d2[1] |
| VSDOT.S8 q13, q5, d2[1] |
| VSDOT.S8 q14, q4, d3[1] |
| VSDOT.S8 q15, q5, d3[1] |
| BHS 2b |
| |
| # Is there a remainder?- 4 bytes of A |
| TST r5, 4 |
| BNE 4f |
| |
| 3: |
| # ks loop |
| SUBS r14, r14, 16 // ks -= MR * sizeof(void*) |
| BHI 1b |
| |
| LDR r7, [sp, 112] // cn_stride |
| LDR r14, [sp, 56] // p = ks |
| |
| $if REQUANTIZATION == "RNDNU": |
| # RNDNU quantization |
| VDUP.32 q0, d12[0] // right_pre_shift |
| |
| VQSHL.S32 q8, q8, q0 |
| VQSHL.S32 q9, q9, q0 |
| VQSHL.S32 q10, q10, q0 |
| VQSHL.S32 q11, q11, q0 |
| VQSHL.S32 q12, q12, q0 |
| VQSHL.S32 q13, q13, q0 |
| VQSHL.S32 q14, q14, q0 |
| VQSHL.S32 q15, q15, q0 |
| |
| VDUP.32 q2, d13[0] // right_post_shift |
| |
| VQDMULH.S32 q8, q8, d12[1] // multiplier |
| VQDMULH.S32 q9, q9, d12[1] |
| VQDMULH.S32 q10, q10, d12[1] |
| VQDMULH.S32 q11, q11, d12[1] |
| VQDMULH.S32 q12, q12, d12[1] |
| VQDMULH.S32 q13, q13, d12[1] |
| VQDMULH.S32 q14, q14, d12[1] |
| VQDMULH.S32 q15, q15, d12[1] |
| |
| VRSHL.S32 q8, q8, q2 |
| VRSHL.S32 q9, q9, q2 |
| VRSHL.S32 q10, q10, q2 |
| VRSHL.S32 q11, q11, q2 |
| VRSHL.S32 q12, q12, q2 |
| VRSHL.S32 q13, q13, q2 |
| VRSHL.S32 q14, q14, q2 |
| VRSHL.S32 q15, q15, q2 |
| $else: |
| # QC8 FP32 quantization |
| VLD1.8 {q0-q1}, [r9]! |
| |
| VCVT.F32.S32 q8, q8 |
| VCVT.F32.S32 q9, q9 |
| VCVT.F32.S32 q10, q10 |
| VCVT.F32.S32 q11, q11 |
| VCVT.F32.S32 q12, q12 |
| VCVT.F32.S32 q13, q13 |
| VCVT.F32.S32 q14, q14 |
| VCVT.F32.S32 q15, q15 |
| |
| VMUL.F32 q8, q8, q0 // multiplier |
| VMUL.F32 q9, q9, q1 |
| VMUL.F32 q10, q10, q0 |
| VMUL.F32 q11, q11, q1 |
| VMUL.F32 q12, q12, q0 |
| VMUL.F32 q13, q13, q1 |
| VMUL.F32 q14, q14, q0 |
| VMUL.F32 q15, q15, q1 |
| |
| VCVTN.S32.F32 q8, q8 |
| VCVTN.S32.F32 q9, q9 |
| VCVTN.S32.F32 q10, q10 |
| VCVTN.S32.F32 q11, q11 |
| VCVTN.S32.F32 q12, q12 |
| VCVTN.S32.F32 q13, q13 |
| VCVTN.S32.F32 q14, q14 |
| VCVTN.S32.F32 q15, q15 |
| VDUP.16 q0, d13[2] // output_zero_point |
| |
| VQMOVN.S32 d16, q8 |
| VQMOVN.S32 d17, q9 |
| VQMOVN.S32 d18, q10 |
| VQMOVN.S32 d19, q11 |
| VQMOVN.S32 d20, q12 |
| VQMOVN.S32 d21, q13 |
| VQMOVN.S32 d22, q14 |
| VQMOVN.S32 d23, q15 |
| |
| VQADD.S16 q8, q8, q0 |
| VQADD.S16 q9, q9, q0 |
| VQADD.S16 q10, q10, q0 |
| VQADD.S16 q11, q11, q0 |
| |
| VDUP.8 q12, d13[6] // output_min |
| |
| VQMOVN.S16 d0, q8 |
| VQMOVN.S16 d1, q9 |
| VQMOVN.S16 d2, q10 |
| VQMOVN.S16 d3, q11 |
| |
| VDUP.8 q13, d13[7] // output_max |
| |
| VMAX.S8 q0, q0, q12 |
| VMAX.S8 q1, q1, q12 |
| |
| SUBS r1, r1, 8 // nc -= 8 |
| |
| VMIN.S8 q0, q0, q13 |
| VMIN.S8 q1, q1, q13 |
| |
| # Store full 4 x 8 |
| BLO 5f |
| VST1.8 {d3}, [r6], r7 |
| VST1.8 {d2}, [r8], r7 |
| VST1.8 {d1}, [r4], r7 |
| VST1.8 {d0}, [r11], r7 |
| SUB r2, r2, r14 // a -= ks |
| BHI 0b |
| |
| VPOP {d8-d13} |
| ADD sp, sp, 12 // skip pad, r2, r3 |
| POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} |
| |
| 4: |
| # Remainder- 4 bytes of A |
| VLD1.32 {d0[0]}, [r3]! // A0 |
| VLD1.32 {q2}, [r9]! // B0 |
| VLD1.32 {d1[0]}, [r12]! // A1 |
| VLD1.32 {q3}, [r9]! // B1 |
| VLD1.32 {d2[0]}, [r10]! // A2 |
| VLD1.32 {d3[0]}, [r0]! // A3 |
| |
| VSDOT.S8 q8, q2, d0[0] |
| VSDOT.S8 q9, q3, d0[0] |
| VSDOT.S8 q10, q2, d1[0] |
| VSDOT.S8 q11, q3, d1[0] |
| VSDOT.S8 q12, q2, d2[0] |
| VSDOT.S8 q13, q3, d2[0] |
| VSDOT.S8 q14, q2, d3[0] |
| VSDOT.S8 q15, q3, d3[0] |
| B 3b |
| |
| # Store odd width |
| .p2align 3 |
| 5: |
| TST r1, 4 |
| BEQ 6f |
| VST1.32 {d3[0]}, [r6]! |
| VST1.32 {d2[0]}, [r8]! |
| VST1.32 {d1[0]}, [r4]! |
| VST1.32 {d0[0]}, [r11]! |
| VEXT.8 q1, q1, q1, 4 |
| VEXT.8 q0, q0, q0, 4 |
| 6: |
| TST r1, 2 |
| BEQ 7f |
| VST1.16 {d3[0]}, [r6]! |
| VST1.16 {d2[0]}, [r8]! |
| VST1.16 {d1[0]}, [r4]! |
| VST1.16 {d0[0]}, [r11]! |
| VEXT.8 q1, q1, q1, 2 |
| VEXT.8 q0, q0, q0, 2 |
| |
| 7: |
| TST r1, 1 |
| BEQ 8f |
| VST1.8 {d3[0]}, [r6] |
| VST1.8 {d2[0]}, [r8] |
| VST1.8 {d1[0]}, [r4] |
| VST1.8 {d0[0]}, [r11] |
| |
| 8: |
| VPOP {d8-d13} |
| ADD sp, sp, 12 // skip pad, r2, r3 |
| POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} |
| |
| END_FUNCTION xnn_${DATATYPE.lower()}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_4x8c4__aarch32_neondot_ld64 |
| # LINT.ThenChange(4x8c4-rndnu-aarch32-neondot-ld64.cc) |
| #endif // __APPLE__ |
| |
| #ifdef __ELF__ |
| .section ".note.GNU-stack","",%progbits |
| #endif |