blob: 8a7bc1464bc534fbfc9b21373d29c3ad0f4cf2e1 [file] [log] [blame] [edit]
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
$assert SSE in [2, 4]
$assert not XOP or AVX
$assert not AVX or SSE == 4
$assert DATATYPE in ["S8", "U8"]
$assert CHANNEL_TILE % 8 == 0
$assert CHANNEL_TILE >= 8
$assert PIXEL_TILE == 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>
#include <string.h>
$if XOP:
#if defined(__GNUC__) || defined(__clang__)
#include <x86intrin.h>
#else
#include <immintrin.h>
#include <ammintrin.h>
#endif
$else:
$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
#include <${SSE_HEADER}>
#include <xnnpack/common.h>
#include <xnnpack/ibilinear.h>
#include <xnnpack/unaligned.h>
$XINT8_T = {"S8": "int8_t", "U8": "uint8_t"}[DATATYPE]
$_MM_CVTEPX8_EPI16 = {"S8": "_mm_cvtepi8_epi16", "U8": "_mm_cvtepu8_epi16"}[DATATYPE]
$_MM_SRXI_EPI32 = {"S8": "_mm_srai_epi32", "U8": "_mm_srli_epi32"}[DATATYPE]
$_MM_SRXI_EPI16 = {"S8": "_mm_srai_epi16", "U8": "_mm_srli_epi16"}[DATATYPE]
$_MM_PACKXS_EPI16 = {"S8": "_mm_packs_epi16", "U8": "_mm_packus_epi16"}[DATATYPE]
$ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 3: "ssse3", 4: "sse41"}[SSE]
// Bilinear-interpolation micro-kernel template for 8-bit channels, built on
// 128-bit SSE integer intrinsics.
//
// For every output pixel the kernel consumes four source pointers from `input`
// and one pair of 16-bit coefficients from `weights`, then writes `channels`
// interpolated bytes to `output`.
//
// Parameters:
//   output_pixels    - number of output pixels to produce; must be non-zero.
//   channels         - number of 8-bit channels per pixel; must be non-zero.
//   input            - array of source pointers; four are consumed per pixel.
//   input_offset     - byte offset added to each pointer read from `input`.
//   weights          - 16-bit coefficients; two are consumed per pixel.
//   output           - destination for the interpolated bytes.
//   output_increment - byte offset added to `output` after each pixel.
//
// Per channel, with ah = weights[0], av = weights[1], and tl/tr/bl/br the bytes
// read through input[0..3]:
//   t   = tr * ah + tl * (2048 - ah)
//   d   = (br - tr) * ah + (bl - tl) * (2048 - ah)
//   out = saturate((t * 2048 + d * av + (1 << 21)) >> 22)
//
// NOTE(review): the tl/tr/bl/br naming suggests input[0..3] are the top-left,
// top-right, bottom-left and bottom-right neighbours, and the 2048 bias
// suggests Q11 fixed-point weights - confirm against the caller.
//
// The tail path always loads 8 bytes per source row even when fewer than 8
// channels remain; presumably XNN_OOB_READS exists to flag this - confirm
// against its definition.
void xnn_${DATATYPE.lower()}_ibilinear_ukernel__${ISA}_c${CHANNEL_TILE}${"" if PIXEL_TILE == 1 else "x%d" % PIXEL_TILE}(
size_t output_pixels,
size_t channels,
const ${XINT8_T}**restrict input,
size_t input_offset,
const int16_t*restrict weights,
${XINT8_T}*restrict output,
size_t output_increment) XNN_OOB_READS
{
assert(output_pixels != 0);
assert(channels != 0);
do {
// Fetch the four source rows for this pixel and apply the common byte offset.
const ${XINT8_T}* i0 = (const ${XINT8_T}*) ((uintptr_t) input[0] + input_offset);
const ${XINT8_T}* i1 = (const ${XINT8_T}*) ((uintptr_t) input[1] + input_offset);
const ${XINT8_T}* i2 = (const ${XINT8_T}*) ((uintptr_t) input[2] + input_offset);
const ${XINT8_T}* i3 = (const ${XINT8_T}*) ((uintptr_t) input[3] + input_offset);
input += 4;
// Load both 16-bit coefficients as one 32-bit value. memcpy is used instead of
// dereferencing a cast pointer: `weights` is only guaranteed int16_t alignment,
// and int16_t objects must not be accessed through an int lvalue.
int packed_weights;
memcpy(&packed_weights, weights, sizeof(packed_weights));
const __m128i valpha = _mm_cvtsi32_si128(packed_weights);
weights += 2;
// Broadcast weights[0] into every 16-bit lane.
__m128i valphah = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(0, 0, 0, 0));
valphah = _mm_unpacklo_epi64(valphah, valphah);
// Broadcast weights[1]: into every 16-bit lane for SSE2 (used with 16-bit
// multiplies), otherwise zero-extended into every 32-bit lane (used with
// _mm_mullo_epi32).
$if SSE == 2:
__m128i valphav = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(1, 1, 1, 1));
valphav = _mm_unpacklo_epi64(valphav, valphav);
$else:
__m128i valphav = _mm_srli_epi32(valpha, 16);
valphav = _mm_shuffle_epi32(valphav, _MM_SHUFFLE(0, 0, 0, 0));
// Turn valphah into interleaved pairs (ah, 2048 - ah): even 16-bit lanes keep
// ah, odd lanes become 0x0800 - ah. The non-blend variant computes the odd
// lanes as (~ah) + 0x0801, which is the same value.
$if SSE == 4:
valphah = _mm_blend_epi16(valphah, _mm_sub_epi16(_mm_set1_epi32(0x08000000), valphah), 0xAA);
$else:
valphah = _mm_xor_si128(valphah, _mm_set1_epi32(0xFFFF0000));
valphah = _mm_add_epi16(valphah, _mm_set1_epi32(0x08010000));
// Rounding term 1 << 21, i.e. half of the final divisor 1 << 22.
const __m128i vrounding = _mm_set1_epi32(0x00200000);
size_t c = channels;
// Main loop: ${CHANNEL_TILE} channels per iteration, in groups of 8.
$if CHANNEL_TILE > 8:
for (; c >= ${CHANNEL_TILE} * sizeof(${XINT8_T}); c -= ${CHANNEL_TILE} * sizeof(${XINT8_T})) {
// Load 8 bytes per row per group and widen each byte to 16 bits. SSE4.1 widens
// with a conversion intrinsic at load time; otherwise the raw bytes are loaded
// here and widened below (unpack with zero for U8, unpack with self followed by
// an arithmetic shift for S8).
$if SSE == 4:
const __m128i vtl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
const __m128i vtr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
const __m128i vbl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
const __m128i vbr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
$for C in range(8, CHANNEL_TILE, 8):
const __m128i vtl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i0 + ${C})));
const __m128i vtr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i1 + ${C})));
const __m128i vbl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i2 + ${C})));
const __m128i vbr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i3 + ${C})));
$else:
__m128i vtl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i0);
__m128i vtr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i1);
__m128i vbl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i2);
__m128i vbr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i3);
$for C in range(8, CHANNEL_TILE, 8):
__m128i vtl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i0 + ${C}));
__m128i vtr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i1 + ${C}));
__m128i vbl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i2 + ${C}));
__m128i vbr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i3 + ${C}));
i0 += ${CHANNEL_TILE};
i1 += ${CHANNEL_TILE};
i2 += ${CHANNEL_TILE};
i3 += ${CHANNEL_TILE};
$if SSE != 4:
$if DATATYPE == "U8":
__m128i vzero = _mm_setzero_si128();
$for C in range(0, CHANNEL_TILE, 8):
vtl${ABC[C:C+8]} = _mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vzero);
vtr${ABC[C:C+8]} = _mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vzero);
vbl${ABC[C:C+8]} = _mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vzero);
vbr${ABC[C:C+8]} = _mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vzero);
$else:
$for C in range(0, CHANNEL_TILE, 8):
vtl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vtl${ABC[C:C+8]}), 8);
vtr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vtr${ABC[C:C+8]}), 8);
vbl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vbl${ABC[C:C+8]}), 8);
vbr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vbr${ABC[C:C+8]}), 8);
// Row differences (i3 - i1, i2 - i0) and weighted sums. Interleaving (r, l)
// pairs and multiply-adding against (ah, 2048 - ah) yields
// r * ah + l * (2048 - ah) in each 32-bit lane.
$for C in range(0, CHANNEL_TILE, 8):
const __m128i vdr${ABC[C:C+8]} = _mm_sub_epi16(vbr${ABC[C:C+8]}, vtr${ABC[C:C+8]});
const __m128i vt${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah);
const __m128i vdl${ABC[C:C+8]} = _mm_sub_epi16(vbl${ABC[C:C+8]}, vtl${ABC[C:C+8]});
const __m128i vt${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah);
$for C in range(0, CHANNEL_TILE, 8):
const __m128i vd${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah);
const __m128i vd${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah);
// Multiply the weighted difference by av, keeping the low 32 bits. Without
// _mm_mullo_epi32 the product is assembled from 16-bit partial products: the
// unsigned high halves shifted up by 16, plus the 16-bit low products.
$if SSE == 4:
$for C in range(0, CHANNEL_TILE, 4):
__m128i vacc${ABC[C:C+4]} = _mm_mullo_epi32(vd${ABC[C:C+4]}, valphav);
$else:
$for C in range(0, CHANNEL_TILE, 4):
__m128i vacc${ABC[C:C+4]} = _mm_slli_epi32(_mm_mulhi_epu16(vd${ABC[C:C+4]}, valphav), 16);
$for C in range(0, CHANNEL_TILE, 4):
vacc${ABC[C:C+4]} = _mm_add_epi16(_mm_mullo_epi16(vd${ABC[C:C+4]}, valphav), vacc${ABC[C:C+4]});
// Add the weighted sum scaled by 2048 (shift left by 11).
$for C in range(0, CHANNEL_TILE, 4):
vacc${ABC[C:C+4]} = _mm_add_epi32(_mm_slli_epi32(vt${ABC[C:C+4]}, 11), vacc${ABC[C:C+4]});
// Round and shift right by 22. The low 16 bits of vrounding are zero, so the
// 16-bit add only touches the upper half of each 32-bit lane.
$for C in range(0, CHANNEL_TILE, 4):
vacc${ABC[C:C+4]} = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc${ABC[C:C+4]}, vrounding), 22);
// Narrow 32 -> 16 -> 8 bits with saturation and store.
$for C in range(0, CHANNEL_TILE, 8):
const __m128i vacc${ABC[C:C+8]} = _mm_packs_epi32(vacc${ABC[C:C+4]}, vacc${ABC[C+4:C+8]});
$for C in range(0, CHANNEL_TILE, 16):
$if C + 8 < CHANNEL_TILE:
const __m128i vo${ABC[C:C+16]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]});
$else:
const __m128i vo${ABC[C:C+8]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C:C+8]});
_mm_storeu_si128((__m128i*) output, vo${ABC[0:16]});
$for C in range(16, CHANNEL_TILE, 16):
$if C + 8 < CHANNEL_TILE:
_mm_storeu_si128((__m128i*) (output + ${C}), vo${ABC[C:C+16]});
$else:
_mm_storel_epi64((__m128i*) (output + ${C}), vo${ABC[C:C+8]});
output += ${CHANNEL_TILE};
}
// 8-channel loop: same arithmetic as the main loop, one group of 8 at a time.
for (; c >= 8 * sizeof(${XINT8_T}); c -= 8 * sizeof(${XINT8_T})) {
$if SSE == 4:
const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
i0 += 8;
const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
i1 += 8;
const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
i2 += 8;
const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
i3 += 8;
$else:
__m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0);
i0 += 8;
__m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1);
i1 += 8;
__m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2);
i2 += 8;
__m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3);
i3 += 8;
$if SSE != 4:
$if DATATYPE == "U8":
__m128i vzero = _mm_setzero_si128();
vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero);
vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero);
vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero);
vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero);
$else:
vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8);
vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8);
vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8);
vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8);
const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567);
const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah);
const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567);
const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah);
const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah);
const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah);
$if SSE == 4:
__m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav);
__m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav);
$else:
__m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16);
__m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16);
vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);
vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123);
vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567);
vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22);
vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22);
const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567);
const __m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567);
_mm_storel_epi64((__m128i*) output, vo01234567);
output += 8;
}
// Tail: 1..7 channels remain. A full 8-byte group is loaded and computed, but
// only the low `c` result bytes are stored.
if XNN_UNLIKELY(c != 0) {
$if SSE == 4:
const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
$else:
__m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0);
__m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1);
__m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2);
__m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3);
$if SSE != 4:
$if DATATYPE == "U8":
__m128i vzero = _mm_setzero_si128();
vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero);
vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero);
vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero);
vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero);
$else:
vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8);
vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8);
vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8);
vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8);
const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567);
const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah);
const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567);
const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah);
const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah);
const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah);
$if SSE == 4:
__m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav);
__m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav);
$else:
__m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16);
__m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16);
vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);
vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123);
vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567);
vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22);
vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22);
const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567);
__m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567);
// Store 4, then 2, then 1 byte according to the bits of c, shifting the bytes
// already written out of the low end of the result.
if (c & (4 * sizeof(${XINT8_T}))) {
unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vo01234567));
output += 4;
vo01234567 = _mm_srli_epi64(vo01234567, 32);
}
$if SSE == 4:
if (c & (2 * sizeof(${XINT8_T}))) {
unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vo01234567, 0));
output += 2;
vo01234567 = _mm_srli_epi32(vo01234567, 16);
}
if (c & (1 * sizeof(${XINT8_T}))) {
*output++ = (uint8_t) _mm_extract_epi8(vo01234567, 0);
}
$else:
uint32_t vo0123 = (uint32_t) _mm_cvtsi128_si32(vo01234567);
if (c & (2 * sizeof(${XINT8_T}))) {
unaligned_store_u16(output, (uint16_t) vo0123);
output += 2;
vo0123 >>= 16;
}
if (c & (1 * sizeof(${XINT8_T}))) {
*output++ = (uint8_t) vo0123;
}
}
// Advance to the start of the next output pixel.
output = (${XINT8_T}*) ((uintptr_t) output + output_increment);
} while (--output_pixels != 0);
}