blob: 7de01ff55aefeffabe44099c008cdf5b1c3ff204 [file] [log] [blame]
// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// simd_wrappers_msa.h: MSA specialization of simd_wrappers.h
16
17#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_
18#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_
19
20#include <msa.h>
21
22namespace gemmlowp {
23
24using Int32x4 = v4i32;
25using Int16x8 = v8i16;
26using Uint8x16 = v16i8;
27
28template <int ScalarCount>
29struct RegisterType<std::int32_t, ScalarCount> {
30 using Type =
31 typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
32};
33
34template <int ScalarCount>
35struct RegisterType<std::int16_t, ScalarCount> {
Miao Wang70ba50c2019-08-08 12:30:36 -070036 using Type = typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
Miao Wang7d0d5a62018-02-23 11:33:20 -080037};
38
39template <int ScalarCount>
40struct RegisterType<std::uint8_t, ScalarCount> {
41 using Type = typename std::conditional<
42 ScalarCount >= 16, Uint8x16,
43 typename std::conditional<ScalarCount >= 4, std::uint32_t,
44 std::uint8_t>::type>::type;
45};
46
47inline Int32x4 LoadInt32x4(const std::int32_t* src) {
48 return __builtin_msa_ld_w(const_cast<std::int32_t*>(src), 0);
49}
50
51inline Int32x4 LoadInt32x4(const Int32x4* src) {
52 return __builtin_msa_ld_w(const_cast<Int32x4*>(src), 0);
53}
54
55inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
56 __builtin_msa_st_w(value, dst, 0);
57}
58
59inline void StoreInt32x4(Int32x4* dst, Int32x4 value) {
60 __builtin_msa_st_w(value, dst, 0);
61}
62
63inline Int16x8 LoadInt16x8(const std::int16_t* src) {
64 return __builtin_msa_ld_h(const_cast<std::int16_t*>(src), 0);
65}
66
67inline Int16x8 LoadInt16x8(const Int16x8* src) {
68 return __builtin_msa_ld_h(const_cast<Int16x8*>(src), 0);
69}
70
Miao Wang70ba50c2019-08-08 12:30:36 -070071inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); }
Miao Wang7d0d5a62018-02-23 11:33:20 -080072
Miao Wang70ba50c2019-08-08 12:30:36 -070073inline void StoreInt16x8(Int16x8* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); }
Miao Wang7d0d5a62018-02-23 11:33:20 -080074
75inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
76 return __builtin_msa_ld_b(const_cast<std::uint8_t*>(src), 0);
77}
78
79inline Uint8x16 LoadUint8x16(const Uint8x16* src) {
80 return __builtin_msa_ld_b(const_cast<Uint8x16*>(src), 0);
81}
82
83inline void StoreUint8x16(std::uint8_t* dst, Uint8x16 value) {
84 __builtin_msa_st_b(value, dst, 0);
85}
86
87inline void StoreUint8x16(Uint8x16* dst, Uint8x16 value) {
88 __builtin_msa_st_b(value, dst, 0);
89}
90
91template <int Lane>
92std::int32_t GetLane(Int32x4 value) {
93 return __builtin_msa_copy_s_w(value, Lane);
94}
95
96template <int Lane>
97Int32x4 DupLane(Int32x4 value) {
98 static_assert(Lane >= 0 && Lane <= 3, "");
99 return __builtin_msa_splati_w(value, Lane);
100}
101
102inline Int32x4 Mul(Int32x4 a, std::int32_t b) {
103 return __builtin_msa_mulv_w(a, __builtin_msa_fill_w(b));
104}
105
106inline Int32x4 Min(Int32x4 a, Int32x4 b) { return __builtin_msa_min_s_w(a, b); }
107
108inline Int32x4 Max(Int32x4 a, Int32x4 b) { return __builtin_msa_max_s_w(a, b); }
109
110inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
111 return __builtin_msa_mulr_q_w(a, __builtin_msa_fill_w(b));
112}
113
114template <int Lane>
115Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
116 static_assert(Lane >= 0 && Lane <= 3, "");
117 return __builtin_msa_mulv_w(a, __builtin_msa_splati_w(b, Lane));
118}
119
120static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
121 // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c).
122#if 0
123 return __builtin_msa_maddv_w(a, b, c);
124#else
125 asm volatile("maddv.w %w[a], %w[b], %w[c]\n"
126 // Outputs
127 : [a] "+f"(a)
128 // Inputs
129 : [b] "f"(b), [c] "f"(c));
130 return a;
131#endif
132}
133
134inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
135 Int32x4 tmp = LoadInt32x4(acc);
136 tmp = workaround_msa_maddv_w(tmp, lhs, rhs);
137 StoreInt32x4(acc, tmp);
138}
139
140inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
141 Int32x4 tmp = LoadInt32x4(acc);
142 tmp = workaround_msa_maddv_w(tmp, lhs, __builtin_msa_fill_w(rhs));
143 StoreInt32x4(acc, tmp);
144}
145
146template <int Lane>
147inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
148 static_assert(Lane >= 0 && Lane <= 3, "");
149 Int32x4 tmp = LoadInt32x4(acc);
150 tmp = workaround_msa_maddv_w(tmp, lhs, __builtin_msa_splati_w(rhs, Lane));
151 StoreInt32x4(acc, tmp);
152}
153
154template <>
155struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
156 static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
157 RegBlockUint8<8, 8> result;
158 for (int i = 0; i < 4; i++) {
159 result.buf.reg[i] = LoadUint8x16(src + 16 * i);
160 }
161 return result;
162 }
163};
164
165template <>
166struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
167 static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
168 RegBlockInt32<8, 8> result;
169 for (int i = 0; i < 16; i++) {
170 result.buf.reg[i] = LoadInt32x4(src + 4 * i);
171 }
172 return result;
173 }
174};
175
176template <>
177struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
178 static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
179 RegBlockInt16<8, 8> result;
180 for (int i = 0; i < 8; i++) {
181 result.buf.reg[i] = LoadInt16x8(src + 8 * i);
182 }
183 return result;
184 }
185};
186
187} // end namespace gemmlowp
188
189#include "simd_wrappers_common_neon_sse.h"
190
191#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_