Miao Wang | a9fd919 | 2017-07-06 11:06:31 -0700 | [diff] [blame] | 1 | // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | // multi_thread_gemv.h: Entry point to the multithreaded version of the |
| 16 | // generated (meta) gemv library. |
| 17 | |
| 18 | #ifndef GEMMLOWP_META_MULTI_THREAD_GEMV_H_ |
| 19 | #define GEMMLOWP_META_MULTI_THREAD_GEMV_H_ |
| 20 | |
| 21 | #ifdef GEMMLOWP_NEON |
| 22 | |
| 23 | #include "legacy_multi_thread_common.h" |
| 24 | #include "legacy_operations_common.h" |
| 25 | #include "legacy_single_thread_gemm.h" |
| 26 | |
| 27 | namespace gemmlowp { |
| 28 | namespace meta { |
| 29 | namespace internal { |
| 30 | |
| 31 | class GemvQuantized8BitOperation : public Quantized8BitOperation { |
| 32 | public: |
| 33 | GemvQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset, |
| 34 | std::int32_t sum_offset, std::int32_t multiplier, |
| 35 | std::int32_t shift) |
| 36 | : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier, |
| 37 | shift) {} |
| 38 | |
| 39 | void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs, |
| 40 | const std::uint8_t* rhs, std::int32_t m, |
| 41 | std::int32_t n, std::int32_t k, std::uint8_t* result, |
| 42 | std::int32_t result_stride) const { |
| 43 | gemv_q8(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, sum_offset, |
| 44 | multiplier, shift, result); |
| 45 | } |
| 46 | |
| 47 | static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n, |
| 48 | std::int32_t k) { |
| 49 | return 128 * 1024; |
| 50 | } |
| 51 | }; |
| 52 | |
| 53 | class GemvFloatOperation : public FloatOperation { |
| 54 | public: |
| 55 | GemvFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset, |
| 56 | float result_offset) |
| 57 | : FloatOperation(lhs_offset, rhs_offset, result_offset) {} |
| 58 | |
| 59 | void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs, |
| 60 | const std::uint8_t* rhs, std::int32_t m, |
| 61 | std::int32_t n, std::int32_t k, float* result, |
| 62 | std::int32_t result_stride) const { |
| 63 | gemv_f(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result_offset, |
| 64 | result); |
| 65 | } |
| 66 | |
| 67 | static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n, |
| 68 | std::int32_t k) { |
| 69 | return 128 * 1024; |
| 70 | } |
| 71 | }; |
| 72 | |
| 73 | class GemvInt32Operation : public Int32Operation { |
| 74 | public: |
| 75 | GemvInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset) |
| 76 | : Int32Operation(lhs_offset, rhs_offset) {} |
| 77 | |
| 78 | void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs, |
| 79 | const std::uint8_t* rhs, std::int32_t m, |
| 80 | std::int32_t n, std::int32_t k, std::int32_t* result, |
| 81 | std::int32_t result_stride) const { |
| 82 | gemv_i32(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result); |
| 83 | } |
| 84 | |
| 85 | static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n, |
| 86 | std::int32_t k) { |
| 87 | return 128 * 1024; |
| 88 | } |
| 89 | }; |
| 90 | |
| 91 | } // namespace internal |
| 92 | |
| 93 | std::int32_t gemv_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k, |
| 94 | std::int32_t max_threads) { |
| 95 | return internal::ResolveMaxThreads(max_threads) * |
| 96 | internal::GemvQuantized8BitOperation::ScratchPerThread(m, n, k); |
| 97 | } |
| 98 | |
| 99 | void multi_thread_gemv_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads, |
| 100 | std::uint8_t* scratch, const std::uint8_t* lhs, |
| 101 | const std::uint8_t* rhs, std::int32_t n, |
| 102 | std::int32_t k, std::int32_t lhs_offset, |
| 103 | std::int32_t rhs_offset, std::int32_t sum_offset, |
| 104 | std::int32_t multiplier, std::int32_t shift, |
| 105 | std::uint8_t* result) { |
| 106 | max_threads = internal::ResolveMaxThreads(max_threads); |
| 107 | internal::GemvQuantized8BitOperation operation(lhs_offset, rhs_offset, |
| 108 | sum_offset, multiplier, shift); |
| 109 | if (max_threads == 1) { |
| 110 | operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n); |
| 111 | } else { |
| 112 | internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1, |
| 113 | n, k, result, n, operation); |
| 114 | } |
| 115 | } |
| 116 | |
| 117 | std::int32_t gemv_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k, |
| 118 | std::int32_t max_threads) { |
| 119 | return internal::ResolveMaxThreads(max_threads) * |
| 120 | internal::GemvFloatOperation::ScratchPerThread(m, n, k); |
| 121 | } |
| 122 | |
| 123 | void multi_thread_gemv_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads, |
| 124 | std::uint8_t* scratch, const std::uint8_t* lhs, |
| 125 | const std::uint8_t* rhs, std::int32_t n, |
| 126 | std::int32_t k, std::int32_t lhs_offset, |
| 127 | std::int32_t rhs_offset, float result_offset, |
| 128 | float* result) { |
| 129 | max_threads = internal::ResolveMaxThreads(max_threads); |
| 130 | internal::GemvFloatOperation operation(lhs_offset, rhs_offset, result_offset); |
| 131 | if (max_threads == 1) { |
| 132 | operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n); |
| 133 | } else { |
| 134 | internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1, |
| 135 | n, k, result, n, operation); |
| 136 | } |
| 137 | } |
| 138 | |
| 139 | std::int32_t gemv_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k, |
| 140 | std::int32_t max_threads) { |
| 141 | return internal::ResolveMaxThreads(max_threads) * |
| 142 | internal::GemvInt32Operation::ScratchPerThread(m, n, k); |
| 143 | } |
| 144 | |
| 145 | void multi_thread_gemv_i32(gemmlowp::WorkersPool* pool, |
| 146 | std::int32_t max_threads, std::uint8_t* scratch, |
| 147 | const std::uint8_t* lhs, const std::uint8_t* rhs, |
| 148 | std::int32_t n, std::int32_t k, |
| 149 | std::int32_t lhs_offset, std::int32_t rhs_offset, |
| 150 | std::int32_t* result) { |
| 151 | max_threads = internal::ResolveMaxThreads(max_threads); |
| 152 | internal::GemvInt32Operation operation(lhs_offset, rhs_offset); |
| 153 | if (max_threads == 1) { |
| 154 | operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n); |
| 155 | } else { |
| 156 | internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1, |
| 157 | n, k, result, n, operation); |
| 158 | } |
| 159 | } |
| 160 | |
| 161 | } // namespace meta |
| 162 | } // namespace gemmlowp |
| 163 | |
| 164 | #else |
| 165 | #warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!" |
| 166 | #endif |
| 167 | |
| 168 | #endif // GEMMLOWP_META_MULTI_THREAD_GEMV_H_ |