| // Copyright (c) Facebook, Inc. and its affiliates. |
| // All rights reserved. |
| // |
| // Copyright 2019 Google LLC |
| // |
| // This source code is licensed under the BSD-style license found in the |
| // LICENSE file in the root directory of this source tree. |
| |
| #include <arm_neon.h> |
| |
| #include <xnnpack/zip.h> |
| |
| |
| void xnn_x8_zip_xm_ukernel__neon( |
| size_t n, |
| size_t m, |
| const uint8_t* input, |
| uint8_t* output) |
| { |
| const uint8_t* w = input; |
| const size_t input_increment = n * 3; |
| const size_t output_increment = 4 - m * n; |
| const uint8_t* last_input = w + n * (m - 1); |
| uint8_t* last_output = (uint8_t*) ((uintptr_t) output + (m - 4)); |
| |
| if (n >= 8) { |
| for (size_t i = 0; i < m; i += 4) { |
| size_t k = n; |
| w = (const uint8_t*) ((uintptr_t) w + input_increment); |
| if (w >= last_input) { |
| w = last_input; |
| } |
| const uint8_t* z = (const uint8_t*) ((uintptr_t) w - n); |
| const uint8_t* y = (const uint8_t*) ((uintptr_t) z - n); |
| const uint8_t* x = (const uint8_t*) ((uintptr_t) y - n); |
| while (k >= 8) { |
| const uint8x8_t vx = vld1_u8(x); x += 8; |
| const uint8x8_t vy = vld1_u8(y); y += 8; |
| const uint8x8_t vz = vld1_u8(z); z += 8; |
| const uint8x8_t vw = vld1_u8(w); w += 8; |
| |
| const uint8x8x2_t vxy = vzip_u8(vx, vy); |
| const uint8x8x2_t vzw = vzip_u8(vz, vw); |
| const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0])); |
| const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1])); |
| |
| vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[0]), 0); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[0]), 1); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[1]), 0); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[1]), 1); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[0]), 0); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[0]), 1); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[1]), 0); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[1]), 1); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| k -= 8; |
| } |
| if (k != 0) { |
| const size_t address_increment = k - 8; |
| x = (const uint8_t*) ((uintptr_t) x + address_increment); |
| y = (const uint8_t*) ((uintptr_t) y + address_increment); |
| z = (const uint8_t*) ((uintptr_t) z + address_increment); |
| w = (const uint8_t*) ((uintptr_t) w + address_increment); |
| const int64x1_t vshift = vmov_n_s64(8 * address_increment); |
| |
| const uint64x1_t vx = vshl_u64(vreinterpret_u64_u8(vld1_u8(x)), vshift); |
| const uint64x1_t vy = vshl_u64(vreinterpret_u64_u8(vld1_u8(y)), vshift); |
| const uint64x1_t vz = vshl_u64(vreinterpret_u64_u8(vld1_u8(z)), vshift); |
| const uint64x1_t vw = vshl_u64(vreinterpret_u64_u8(vld1_u8(w)), vshift); w += 8; |
| const uint8x8x2_t vxy = vzip_u8(vreinterpret_u8_u64(vx), vreinterpret_u8_u64(vy)); |
| const uint8x8x2_t vzw = vzip_u8(vreinterpret_u8_u64(vz), vreinterpret_u8_u64(vw)); |
| const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0])); |
| const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1])); |
| |
| uint32x2_t vxyzw0 = vreinterpret_u32_u16(vxyzw_lo.val[0]); |
| uint32x2_t vxyzw1 = vreinterpret_u32_u16(vxyzw_lo.val[1]); |
| uint32x2_t vxyzw2 = vreinterpret_u32_u16(vxyzw_hi.val[0]); |
| uint32x2_t vxyzw3 = vreinterpret_u32_u16(vxyzw_hi.val[1]); |
| |
| if (k & 4) { |
| vst1_lane_u32((void*) output, vxyzw0, 0); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vxyzw0, 1); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vxyzw1, 0); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vxyzw1, 1); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vxyzw0 = vxyzw2; |
| vxyzw1 = vxyzw3; |
| } |
| |
| if (k & 2) { |
| vst1_lane_u32((void*) output, vxyzw0, 0); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vst1_lane_u32((void*) output, vxyzw0, 1); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| |
| vxyzw0 = vxyzw1; |
| } |
| if (k & 1) { |
| vst1_lane_u32((void*) output, vxyzw0, 0); |
| output = (uint8_t*) ((uintptr_t) output + m); |
| } |
| } |
| output = (uint8_t*) ((uintptr_t) output + output_increment); |
| if (output > last_output) { |
| output = last_output; |
| } |
| } |
| } else { |
| const uint8_t* i = input; |
| uint8_t* o = output; |
| size_t k = n; |
| do { |
| size_t l = m; |
| const uint8_t* ii = i++; |
| do { |
| *o++ = *ii; |
| ii += n; |
| } while (--l != 0); |
| } while (--k != 0); |
| } |
| } |