| // Copyright 2021 Google LLC |
| // |
| // This source code is licensed under the BSD-style license found in the |
| // LICENSE file in the root directory of this source tree. |
| |
| $assert SSE in [2, 4] |
| $assert not XOP or AVX |
| $assert not AVX or SSE == 4 |
| $assert DATATYPE in ["S8", "U8"] |
| $assert CHANNEL_TILE % 8 == 0 |
| $assert CHANNEL_TILE >= 8 |
| $assert PIXEL_TILE == 1 |
| $ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" |
| #include <assert.h> |
| |
| $if XOP: |
| #if defined(__GNUC__) || defined(__clang__) |
| #include <x86intrin.h> |
| #else |
| #include <immintrin.h> |
| #include <ammintrin.h> |
| #endif |
| $else: |
| $SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE] |
| #include <${SSE_HEADER}> |
| |
| #include <xnnpack/common.h> |
| #include <xnnpack/ibilinear.h> |
| #include <xnnpack/unaligned.h> |
| |
| |
| $XINT8_T = {"S8": "int8_t", "U8": "uint8_t"}[DATATYPE] |
| $_MM_CVTEPX8_EPI16 = {"S8": "_mm_cvtepi8_epi16", "U8": "_mm_cvtepu8_epi16"}[DATATYPE] |
| $_MM_SRXI_EPI32 = {"S8": "_mm_srai_epi32", "U8": "_mm_srli_epi32"}[DATATYPE] |
| $_MM_SRXI_EPI16 = {"S8": "_mm_srai_epi16", "U8": "_mm_srli_epi16"}[DATATYPE] |
| $_MM_PACKXS_EPI16 = {"S8": "_mm_packs_epi16", "U8": "_mm_packus_epi16"}[DATATYPE] |
| $ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 3: "ssse3", 4: "sse41"}[SSE] |
| void xnn_${DATATYPE.lower()}_ibilinear_ukernel__${ISA}_c${CHANNEL_TILE}${"" if PIXEL_TILE == 1 else "x%d" % PIXEL_TILE}( |
| size_t output_pixels, |
| size_t channels, |
| const ${XINT8_T}**restrict input, |
| size_t input_offset, |
| const int16_t*restrict weights, |
| ${XINT8_T}*restrict output, |
| size_t output_increment) XNN_OOB_READS |
| { |
| assert(output_pixels != 0); |
| assert(channels != 0); |
| |
| do { |
| const ${XINT8_T}* i0 = (const ${XINT8_T}*) ((uintptr_t) input[0] + input_offset); |
| const ${XINT8_T}* i1 = (const ${XINT8_T}*) ((uintptr_t) input[1] + input_offset); |
| const ${XINT8_T}* i2 = (const ${XINT8_T}*) ((uintptr_t) input[2] + input_offset); |
| const ${XINT8_T}* i3 = (const ${XINT8_T}*) ((uintptr_t) input[3] + input_offset); |
| input += 4; |
| |
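    // Load the (alpha_h, alpha_v) weight pair for this output pixel; the weights are Q11 fixed-point values.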
| const __m128i valpha = _mm_cvtsi32_si128(*((const int*) weights)); |
| weights += 2; |
| __m128i valphah = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(0, 0, 0, 0)); |
| valphah = _mm_unpacklo_epi64(valphah, valphah); |
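    // Broadcast the vertical weight alpha_v: to every 16-bit lane on SSE2, to every 32-bit lane on SSE4.1+.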
| $if SSE == 2: |
| __m128i valphav = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(1, 1, 1, 1)); |
| valphav = _mm_unpacklo_epi64(valphav, valphav); |
| $else: |
| __m128i valphav = _mm_srli_epi32(valpha, 16); |
| valphav = _mm_shuffle_epi32(valphav, _MM_SHUFFLE(0, 0, 0, 0)); |
| |
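    // Interleave (alpha_h, 2048 - alpha_h) pairs in valphah so that _mm_madd_epi16 blends the right and
    // left taps of each row in a single multiply-add.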
| $if SSE == 4: |
| valphah = _mm_blend_epi16(valphah, _mm_sub_epi16(_mm_set1_epi32(0x08000000), valphah), 0xAA); |
| $else: |
| valphah = _mm_xor_si128(valphah, _mm_set1_epi32(0xFFFF0000)); |
| valphah = _mm_add_epi16(valphah, _mm_set1_epi32(0x08010000)); |
| |
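    // The products of two Q11 weights carry 22 fractional bits, so round with 2^21 before the final shift.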
| const __m128i vrounding = _mm_set1_epi32(0x00200000); |
| |
| size_t c = channels; |
| $if CHANNEL_TILE > 8: |
| for (; c >= ${CHANNEL_TILE} * sizeof(${XINT8_T}); c -= ${CHANNEL_TILE} * sizeof(${XINT8_T})) { |
| $if SSE == 4: |
| const __m128i vtl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0)); |
| const __m128i vtr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1)); |
| const __m128i vbl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2)); |
| const __m128i vbr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3)); |
| $for C in range(8, CHANNEL_TILE, 8): |
| const __m128i vtl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i0 + ${C}))); |
| const __m128i vtr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i1 + ${C}))); |
| const __m128i vbl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i2 + ${C}))); |
| const __m128i vbr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i3 + ${C}))); |
| $else: |
| __m128i vtl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i0); |
| __m128i vtr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i1); |
| __m128i vbl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i2); |
| __m128i vbr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i3); |
| $for C in range(8, CHANNEL_TILE, 8): |
| __m128i vtl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i0 + ${C})); |
| __m128i vtr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i1 + ${C})); |
| __m128i vbl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i2 + ${C})); |
| __m128i vbr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i3 + ${C})); |
| i0 += ${CHANNEL_TILE}; |
| i1 += ${CHANNEL_TILE}; |
| i2 += ${CHANNEL_TILE}; |
| i3 += ${CHANNEL_TILE}; |
| |
| $if SSE != 4: |
| $if DATATYPE == "U8": |
| __m128i vzero = _mm_setzero_si128(); |
| $for C in range(0, CHANNEL_TILE, 8): |
| vtl${ABC[C:C+8]} = _mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vzero); |
| vtr${ABC[C:C+8]} = _mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vzero); |
| vbl${ABC[C:C+8]} = _mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vzero); |
| vbr${ABC[C:C+8]} = _mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vzero); |
| $else: |
| $for C in range(0, CHANNEL_TILE, 8): |
| vtl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vtl${ABC[C:C+8]}), 8); |
| vtr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vtr${ABC[C:C+8]}), 8); |
| vbl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vbl${ABC[C:C+8]}), 8); |
| vbr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vbr${ABC[C:C+8]}), 8); |
| |
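        // Horizontal pass: vt is the top row interpolated along x, vd is the (bottom - top) difference
        // interpolated along x.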
| $for C in range(0, CHANNEL_TILE, 8): |
| const __m128i vdr${ABC[C:C+8]} = _mm_sub_epi16(vbr${ABC[C:C+8]}, vtr${ABC[C:C+8]}); |
| const __m128i vt${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah); |
| const __m128i vdl${ABC[C:C+8]} = _mm_sub_epi16(vbl${ABC[C:C+8]}, vtl${ABC[C:C+8]}); |
| const __m128i vt${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah); |
| |
| $for C in range(0, CHANNEL_TILE, 8): |
| const __m128i vd${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah); |
| const __m128i vd${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah); |
| |
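        // Vertical pass: scale the interpolated difference by alpha_v. SSE2 lacks a 32-bit multiply, so the
        // 32x16 product is assembled from _mm_mulhi_epu16 and _mm_mullo_epi16 halves.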
| $if SSE == 4: |
| $for C in range(0, CHANNEL_TILE, 4): |
| __m128i vacc${ABC[C:C+4]} = _mm_mullo_epi32(vd${ABC[C:C+4]}, valphav); |
| $else: |
| $for C in range(0, CHANNEL_TILE, 4): |
| __m128i vacc${ABC[C:C+4]} = _mm_slli_epi32(_mm_mulhi_epu16(vd${ABC[C:C+4]}, valphav), 16); |
| |
          $for C in range(0, CHANNEL_TILE, 4):
            vacc${ABC[C:C+4]} = _mm_add_epi16(_mm_mullo_epi16(vd${ABC[C:C+4]}, valphav), vacc${ABC[C:C+4]});
| |
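        // Accumulate (vt << 11) so both terms carry 22 fractional bits, then round and descale.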
| $for C in range(0, CHANNEL_TILE, 4): |
| vacc${ABC[C:C+4]} = _mm_add_epi32(_mm_slli_epi32(vt${ABC[C:C+4]}, 11), vacc${ABC[C:C+4]}); |
| |
| $for C in range(0, CHANNEL_TILE, 4): |
| vacc${ABC[C:C+4]} = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc${ABC[C:C+4]}, vrounding), 22); |
| |
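        // Narrow the 32-bit results to 16 bits, then to 8 bits, before storing.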
| $for C in range(0, CHANNEL_TILE, 8): |
| const __m128i vacc${ABC[C:C+8]} = _mm_packs_epi32(vacc${ABC[C:C+4]}, vacc${ABC[C+4:C+8]}); |
| |
| $for C in range(0, CHANNEL_TILE, 16): |
| $if C + 8 < CHANNEL_TILE: |
| const __m128i vo${ABC[C:C+16]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}); |
| $else: |
| const __m128i vo${ABC[C:C+8]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C:C+8]}); |
| |
| _mm_storeu_si128((__m128i*) output, vo${ABC[0:16]}); |
| $for C in range(16, CHANNEL_TILE, 16): |
| $if C + 8 < CHANNEL_TILE: |
| _mm_storeu_si128((__m128i*) (output + ${C}), vo${ABC[C:C+16]}); |
| $else: |
| _mm_storel_epi64((__m128i*) (output + ${C}), vo${ABC[C:C+8]}); |
| output += ${CHANNEL_TILE}; |
| } |
| for (; c >= 8 * sizeof(${XINT8_T}); c -= 8 * sizeof(${XINT8_T})) { |
| $if SSE == 4: |
| const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0)); |
| i0 += 8; |
| const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1)); |
| i1 += 8; |
| const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2)); |
| i2 += 8; |
| const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3)); |
| i3 += 8; |
| $else: |
| __m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0); |
| i0 += 8; |
| __m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1); |
| i1 += 8; |
| __m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2); |
| i2 += 8; |
| __m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3); |
| i3 += 8; |
| |
| $if SSE != 4: |
| $if DATATYPE == "U8": |
| __m128i vzero = _mm_setzero_si128(); |
| vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero); |
| vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero); |
| vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero); |
| vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero); |
| $else: |
| vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8); |
| vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8); |
| vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8); |
| vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8); |
| |
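      // Horizontal pass for 8 channels: interpolate the top row and the bottom-minus-top difference along x.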
| const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567); |
| const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah); |
| const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567); |
| const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah); |
| |
| const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah); |
| const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah); |
| |
| $if SSE == 4: |
| __m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav); |
| __m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav); |
| $else: |
| __m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16); |
| __m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16); |
| |
        vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
        vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);
| |
| vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123); |
| vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567); |
| |
| vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22); |
| vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22); |
| |
| const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567); |
| |
| const __m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567); |
| |
| _mm_storel_epi64((__m128i*) output, vo01234567); |
| output += 8; |
| } |
| if XNN_UNLIKELY(c != 0) { |
| $if SSE == 4: |
| const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0)); |
| const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1)); |
| const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2)); |
| const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3)); |
| $else: |
| __m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0); |
| __m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1); |
| __m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2); |
| __m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3); |
| |
| $if SSE != 4: |
| $if DATATYPE == "U8": |
| __m128i vzero = _mm_setzero_si128(); |
| vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero); |
| vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero); |
| vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero); |
| vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero); |
| $else: |
| vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8); |
| vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8); |
| vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8); |
| vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8); |
| |
| const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567); |
| const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah); |
| const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567); |
| const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah); |
| |
| const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah); |
| const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah); |
| |
| $if SSE == 4: |
| __m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav); |
| __m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav); |
| $else: |
| __m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16); |
| __m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16); |
| |
        vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
        vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);
| |
| vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123); |
| vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567); |
| |
| vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22); |
| vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22); |
| |
| const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567); |
| |
| __m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567); |
| |
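      // Store the remaining 1-7 bytes in 4-, 2-, and 1-byte pieces.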
| if (c & (4 * sizeof(${XINT8_T}))) { |
| unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vo01234567)); |
| output += 4; |
| vo01234567 = _mm_srli_epi64(vo01234567, 32); |
| } |
| $if SSE == 4: |
| if (c & (2 * sizeof(${XINT8_T}))) { |
| unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vo01234567, 0)); |
| output += 2; |
| vo01234567 = _mm_srli_epi32(vo01234567, 16); |
| } |
| if (c & (1 * sizeof(${XINT8_T}))) { |
| *output++ = (uint8_t) _mm_extract_epi8(vo01234567, 0); |
| } |
| $else: |
| uint32_t vo0123 = (uint32_t) _mm_cvtsi128_si32(vo01234567); |
| if (c & (2 * sizeof(${XINT8_T}))) { |
| unaligned_store_u16(output, (uint16_t) vo0123); |
| output += 2; |
| vo0123 >>= 16; |
| } |
| if (c & (1 * sizeof(${XINT8_T}))) { |
| *output++ = (uint8_t) vo0123; |
| } |
| } |
| |
| output = (${XINT8_T}*) ((uintptr_t) output + output_increment); |
| } while (--output_pixels != 0); |
| } |