// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <xnnpack/assembly.h>
# void xnn_f32_igemm_minmax_ukernel_4x2__aarch64_neonfma_ld64(
# size_t mr, x0
# size_t nc, x1
# size_t kc, x2 / x0
# size_t ks, x3 / x9
# const float**restrict a, x4
# const float*restrict w, x5
# float*restrict c, x6
# size_t cm_stride, x7
# size_t cn_stride, [sp] -> x10
# size_t a_offset, [sp + 8] -> x11
# const float* zero, [sp + 16] -> x12
# const xnn_f32_minmax_params params [sp + 24] -> (x8)
# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
# A pointers
# x8 a0
# x13 a1
# x14 a2
# x15 a3
# C pointers
# x6 c0
# x16 c1
# x17 c2
# x7 c3 / cm_stride
# Vector register usage
# A0 v0
# A1 v1
# A2 v2
# A3 v3
# B v20 v21
# C v24 v25
# C v26 v27
# C v28 v29
# C v30 v31
# Clamp v4 v5
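# Functional sketch of the kernel, as hedged C-style pseudocode in comments
# (illustrative only, not part of the build; the numbers on the right refer
# to the numbered labels below):
#
#   do {                                       // 0: loop over column pairs
#     acc0..acc3 = {w[0], w[1]}; w += 2;       //    shared 2-float bias
#     for (p = ks; p != 0; p -= 32) {          // 1: loop over A pointer sets
#       load a0..a3 from a; add a_offset to each unless it equals zero
#       for (k = kc; k != 0; k -= 4) {         // 2/4: k loop + 1-float tail
#         accM += *aM++ * {w[0], w[1]}; w += 2;//    one 2-wide FMA per row M
#       }
#     }
#     cM = min(max(accM, params->min), params->max); cM += cn_stride;
#     a -= ks; nc -= 2;                        //    rewind indirection buffer
#   } while (nc > 0);                          // 5: odd nc stores one column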
BEGIN_FUNCTION xnn_f32_igemm_minmax_ukernel_4x2__aarch64_neonfma_ld64
# Load cn_stride, a_offset
LDP x10, x11, [sp]
# Load zero, params pointer
LDP x12, x8, [sp, 16]
# Clamp C pointers
CMP x0, 2 // if mr < 2
ADD x16, x6, x7 // c1 = c0 + cm_stride
CSEL x16, x6, x16, LO // c1 = c0
# Load min/max values
LD2R {v4.2s, v5.2s}, [x8]
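# (LD2R broadcasts the {min, max} pair: min fills both lanes of v4,
# max fills both lanes of v5.)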
ADD x17, x16, x7 // c2 = c1 + cm_stride
CSEL x17, x16, x17, LS // if mr <= 2, c2 = c1
CMP x0, 4 // if mr < 4
ADD x7, x17, x7 // c3 = c2 + cm_stride
CSEL x7, x17, x7, LO // c3 = c2
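# Equivalently (sketch): c1 = (mr < 2) ? c0 : c0 + cm_stride;
#                        c2 = (mr <= 2) ? c1 : c1 + cm_stride;
#                        c3 = (mr < 4) ? c2 : c2 + cm_stride;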
0:
# Load initial bias from w into accumulators
LDR d24, [x5], 8
MOV v26.8b, v24.8b
MOV v28.8b, v24.8b
MOV v30.8b, v24.8b
MOVI v25.2s, 0
MOVI v27.2s, 0
MOVI v29.2s, 0
MOVI v31.2s, 0
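# The second accumulator set (v25, v27, v29, v31) starts at zero and
# gathers the odd-k products of the main loop; it is folded into the
# bias-carrying set with FADD before clamping.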
MOV x9, x3 // p = ks
1:
# Load next 4 A pointers
LDP x8, x13, [x4], 16
LDP x14, x15, [x4], 16
CMP x8, x12 // if a0 == zero
ADD x8, x8, x11 // a0 += a_offset
CSEL x8, x12, x8, EQ // a0 = zero, else a0 + a_offset
CMP x13, x12 // if a1 == zero
ADD x13, x13, x11 // a1 += a_offset
CSEL x13, x12, x13, EQ // a1 = zero, else a1 + a_offset
CMP x14, x12 // if a2 == zero
ADD x14, x14, x11 // a2 += a_offset
CSEL x14, x12, x14, EQ // a2 = zero, else a2 + a_offset
CMP x15, x12 // if a3 == zero
ADD x15, x15, x11 // a3 += a_offset
CSEL x15, x12, x15, EQ // a3 = zero, else a3 + a_offset
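# An A pointer equal to `zero` selects the packed zero buffer (padding rows);
# CSEL keeps `zero` in that case and the speculative ADD result is dropped.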
# Are there at least 2 floats (8 bytes) of A?
SUBS x0, x2, 8 // k = kc - 8
B.LO 4f
# Main loop - 2 floats of A (8 bytes)
2:
LDR d0, [x8], 8
LDP d20, d21, [x5], 16
LDR d1, [x13], 8
LDR d2, [x14], 8
LDR d3, [x15], 8
SUBS x0, x0, 8
FMLA v24.2s, v20.2s, v0.s[0]
FMLA v26.2s, v20.2s, v1.s[0]
FMLA v28.2s, v20.2s, v2.s[0]
FMLA v30.2s, v20.2s, v3.s[0]
FMLA v25.2s, v21.2s, v0.s[1]
FMLA v27.2s, v21.2s, v1.s[1]
FMLA v29.2s, v21.2s, v2.s[1]
FMLA v31.2s, v21.2s, v3.s[1]
B.HS 2b
# Is there a remainder? - 1 float of A (4 bytes)
TBNZ x0, 2, 4f
3:
# ks loop
SUBS x9, x9, 32 // ks -= MR * sizeof(void*)
B.HI 1b
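# Fold the odd-k partial sums into the even-k accumulators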
FADD v24.2s, v24.2s, v25.2s
FADD v26.2s, v26.2s, v27.2s
FADD v28.2s, v28.2s, v29.2s
FADD v30.2s, v30.2s, v31.2s
# Clamp
FMAX v24.2s, v24.2s, v4.2s
SUBS x1, x1, 2 // nc -= 2 (sets flags for the store branches below)
FMAX v26.2s, v26.2s, v4.2s
FMAX v28.2s, v28.2s, v4.2s
FMAX v30.2s, v30.2s, v4.2s
FMIN v24.2s, v24.2s, v5.2s
FMIN v26.2s, v26.2s, v5.2s
FMIN v28.2s, v28.2s, v5.2s
FMIN v30.2s, v30.2s, v5.2s
# Store full 4 x 2
B.LO 5f
STR d30, [x7]
ADD x7, x7, x10
STR d28, [x17]
ADD x17, x17, x10
STR d26, [x16]
ADD x16, x16, x10
STR d24, [x6]
ADD x6, x6, x10
SUB x4, x4, x3 // a -= ks
# nc loop
B.HI 0b
RET
# Remainder - 1 float of A
4:
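# Only one row of B (v20) is needed here; the products join the
# even-k accumulators.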
LDR s0, [x8], 4
LDR d20, [x5], 8
LDR s1, [x13], 4
LDR s2, [x14], 4
LDR s3, [x15], 4
FMLA v24.2s, v20.2s, v0.s[0]
FMLA v26.2s, v20.2s, v1.s[0]
FMLA v28.2s, v20.2s, v2.s[0]
FMLA v30.2s, v20.2s, v3.s[0]
B 3b
# Store odd width
5:
STR s30, [x7]
STR s28, [x17]
STR s26, [x16]
STR s24, [x6]
RET
END_FUNCTION xnn_f32_igemm_minmax_ukernel_4x2__aarch64_neonfma_ld64
#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif