| // Copyright 2022 Google LLC |
| // |
| // This source code is licensed under the BSD-style license found in the |
| // LICENSE file in the root directory of this source tree. |
| |
| // Template parameter BATCH_TILE is the number of int16 elements handled per |
| // iteration of the unrolled loop; it must be a positive multiple of 8. |
| // SIMD_TILE is the resulting number of 8-element vectors per iteration. |
| $assert BATCH_TILE % 8 == 0 |
| $assert BATCH_TILE >= 8 |
| $SIMD_TILE = BATCH_TILE // 8 |
| #include <assert.h> |
| #include <stddef.h> |
| #include <stdint.h> |
| |
| #include <arm_neon.h> |
| |
| #include <xnnpack/math.h> |
| #include <xnnpack/window.h> |
| |
| $SHIFT_VARIANT = "_shift%s" % SHIFT if SHIFT else "" |
| |
| // Multiplies each int16 input element by the matching int16 weight, shifts the |
| // product right by `shift` bits and stores the result to output as a saturated |
| // int16. |
| // |
| // rows       - number of rows to process; must be non-zero. |
| // batch_size - number of elements per row; must be non-zero. |
| // input      - read sequentially across all rows; must not be NULL. |
| // weights    - re-read from its start for every row; must not be NULL. |
| // output     - written sequentially across all rows; must not be NULL. |
| // shift      - right-shift amount: must equal the template SHIFT when SHIFT is |
| //              non-zero, otherwise any run-time value below 32. |
| // |
| // NOTE(review): when batch_size is not a multiple of 8, the tail path loads a |
| // full 8-element vector from both input and weights, i.e. it reads past the |
| // last valid element. Confirm that callers pad these buffers or that this |
| // kernel is meant to carry an out-of-bounds-read annotation. |
| void xnn_s16_window${SHIFT_VARIANT}_ukernel__neon_x${BATCH_TILE}( |
| size_t rows, |
| size_t batch_size, |
| const int16_t* input, |
| const int16_t* weights, |
| int16_t* output, |
| uint32_t shift) |
| { |
| assert(rows != 0); |
| assert(batch_size != 0); |
| assert(input != NULL); |
| assert(weights != NULL); |
| assert(output != NULL); |
| $if SHIFT != 0: |
| assert(shift == ${SHIFT}); |
| $else: |
| assert(shift < 32); |
| |
| $if SHIFT == 0: |
| const int32x4_t vshift = vdupq_n_s32(-(int32_t)shift); // negative to shift right. |
| |
| do { |
| // Weights restart at the beginning for every row; input and output keep advancing. |
| const int16_t* w = weights; |
| // n counts the bytes (not elements) still to be processed in this row. |
| size_t n = batch_size * sizeof(int16_t); |
| $if BATCH_TILE > 8: |
| // Unrolled loop: BATCH_TILE elements per iteration. |
| for (; n >= ${BATCH_TILE} * sizeof(int16_t); n -= ${BATCH_TILE} * sizeof(int16_t)) { |
| $for N in range(SIMD_TILE): |
| const int16x8_t vi${N} = vld1q_s16(input); input += 8; |
| |
| $for N in range(SIMD_TILE): |
| const int16x8_t vw${N} = vld1q_s16(w); w += 8; |
| |
| $if SHIFT == 15: |
| // vqdmulhq_s16 yields the saturated high half of 2*a*b, i.e. (a*b) >> 15. |
| $for N in range(SIMD_TILE): |
| const int16x8_t vout${N} = vqdmulhq_s16(vi${N}, vw${N}); |
| $else: |
| // Widen to 32-bit products, low and high halves separately. |
| $for N in range(SIMD_TILE): |
| int32x4_t vacc${N}_lo = vmull_s16(vget_low_s16(vi${N}), vget_low_s16(vw${N})); |
| int32x4_t vacc${N}_hi = vmull_s16(vget_high_s16(vi${N}), vget_high_s16(vw${N})); |
| |
| $if SHIFT != 0: |
| // Compile-time shift: saturating shift-right-and-narrow back to int16. |
| $for N in range(SIMD_TILE): |
| const int16x4_t vshift${N}_lo = vqshrn_n_s32(vacc${N}_lo, ${SHIFT}); |
| const int16x4_t vshift${N}_hi = vqshrn_n_s32(vacc${N}_hi, ${SHIFT}); |
| |
| $for N in range(SIMD_TILE): |
| const int16x8_t vout${N} = vcombine_s16(vshift${N}_lo, vshift${N}_hi); |
| $else: |
| // Run-time shift: shift right via the negated count in vshift, then saturating narrow. |
| $for N in range(SIMD_TILE): |
| vacc${N}_lo = vshlq_s32(vacc${N}_lo, vshift); |
| vacc${N}_hi = vshlq_s32(vacc${N}_hi, vshift); |
| |
| $for N in range(SIMD_TILE): |
| const int16x8_t vout${N} = vcombine_s16(vqmovn_s32(vacc${N}_lo), vqmovn_s32(vacc${N}_hi)); |
| |
| $for N in range(SIMD_TILE): |
| vst1q_s16(output, vout${N}); output += 8; |
| } |
| |
| // Remainder of full vectors |
| // Same computation as above, one 8-element vector per iteration. |
| for (; n >= 8 * sizeof(int16_t); n -= 8 * sizeof(int16_t)) { |
| const int16x8_t vi = vld1q_s16(input); input += 8; |
| const int16x8_t vw = vld1q_s16(w); w += 8; |
| $if SHIFT == 15: |
| const int16x8_t vout = vqdmulhq_s16(vi, vw); |
| $else: |
| int32x4_t vacc_lo = vmull_s16(vget_low_s16(vi), vget_low_s16(vw)); |
| int32x4_t vacc_hi = vmull_s16(vget_high_s16(vi), vget_high_s16(vw)); |
| $if SHIFT != 0: |
| const int16x4_t vshift_lo = vqshrn_n_s32(vacc_lo, ${SHIFT}); |
| const int16x4_t vshift_hi = vqshrn_n_s32(vacc_hi, ${SHIFT}); |
| const int16x8_t vout = vcombine_s16(vshift_lo, vshift_hi); |
| $else: |
| vacc_lo = vshlq_s32(vacc_lo, vshift); |
| vacc_hi = vshlq_s32(vacc_hi, vshift); |
| const int16x8_t vout = vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)); |
| vst1q_s16(output, vout); output += 8; |
| } |
| |
| assert(n % 2 == 0); |
| // Remainder of 1 to 7 batch_size |
| if XNN_UNLIKELY(n != 0) { |
| // A full 8-element vector is loaded although fewer elements remain; only the |
| // valid results are stored below. input advances by the n remaining bytes; |
| // w is not advanced because it is reset at the top of the next row. |
| const int16x8_t vi = vld1q_s16(input); input = (const int16_t*) ((uintptr_t) input + n); |
| const int16x8_t vw = vld1q_s16(w); |
| $if SHIFT == 15: |
| int16x4_t vout = vqdmulh_s16(vget_low_s16(vi), vget_low_s16(vw)); |
| $else: |
| int32x4_t vacc = vmull_s16(vget_low_s16(vi), vget_low_s16(vw)); |
| $if SHIFT != 0: |
| int16x4_t vout = vqshrn_n_s32(vacc, ${SHIFT}); |
| $else: |
| vacc = vshlq_s32(vacc, vshift); |
| int16x4_t vout = vqmovn_s32(vacc); |
| // 4 or more elements left: store the low four results, then recompute vout |
| // from the high halves so the stores below continue with element 4. |
| if (n & (4 * sizeof(int16_t))) { |
| vst1_s16(output, vout); output += 4; |
| $if SHIFT == 15: |
| vout = vqdmulh_s16(vget_high_s16(vi), vget_high_s16(vw)); |
| $else: |
| vacc = vmull_s16(vget_high_s16(vi), vget_high_s16(vw)); |
| $if SHIFT != 0: |
| vout = vqshrn_n_s32(vacc, ${SHIFT}); |
| $else: |
| vacc = vshlq_s32(vacc, vshift); |
| vout = vqmovn_s32(vacc); |
| } |
| // Store two int16 results as a single 32-bit lane, then rotate vout by two |
| // lanes so the next unstored result sits in lane 0. |
| if (n & (2 * sizeof(int16_t))) { |
| vst1_lane_u32((void*) output, vreinterpret_u32_s16(vout), 0); output += 2; |
| vout = vext_s16(vout, vout, 2); |
| } |
| // Store the final single result from lane 0. |
| if (n & (1 * sizeof(int16_t))) { |
| vst1_lane_s16(output, vout, 0); output += 1; |
| } |
| } |
| |
| } while (--rows != 0); |
| } |