blob: 7af5684baa098ac99d9bdc05012e90371dac5288 [file] [log] [blame]
Miao Wanga9fd9192017-07-06 11:06:31 -07001// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// multi_thread_gemv.h: Entry point to the multithreaded version of the
16// generated (meta) gemv library.
17
18#ifndef GEMMLOWP_META_MULTI_THREAD_GEMV_H_
19#define GEMMLOWP_META_MULTI_THREAD_GEMV_H_
20
21#ifdef GEMMLOWP_NEON
22
23#include "legacy_multi_thread_common.h"
24#include "legacy_operations_common.h"
25#include "legacy_single_thread_gemm.h"
26
27namespace gemmlowp {
28namespace meta {
29namespace internal {
30
31class GemvQuantized8BitOperation : public Quantized8BitOperation {
32 public:
33 GemvQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
34 std::int32_t sum_offset, std::int32_t multiplier,
35 std::int32_t shift)
36 : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier,
37 shift) {}
38
39 void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
40 const std::uint8_t* rhs, std::int32_t m,
41 std::int32_t n, std::int32_t k, std::uint8_t* result,
42 std::int32_t result_stride) const {
43 gemv_q8(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, sum_offset,
44 multiplier, shift, result);
45 }
46
47 static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
48 std::int32_t k) {
49 return 128 * 1024;
50 }
51};
52
53class GemvFloatOperation : public FloatOperation {
54 public:
55 GemvFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
56 float result_offset)
57 : FloatOperation(lhs_offset, rhs_offset, result_offset) {}
58
59 void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
60 const std::uint8_t* rhs, std::int32_t m,
61 std::int32_t n, std::int32_t k, float* result,
62 std::int32_t result_stride) const {
63 gemv_f(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result_offset,
64 result);
65 }
66
67 static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
68 std::int32_t k) {
69 return 128 * 1024;
70 }
71};
72
73class GemvInt32Operation : public Int32Operation {
74 public:
75 GemvInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
76 : Int32Operation(lhs_offset, rhs_offset) {}
77
78 void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
79 const std::uint8_t* rhs, std::int32_t m,
80 std::int32_t n, std::int32_t k, std::int32_t* result,
81 std::int32_t result_stride) const {
82 gemv_i32(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result);
83 }
84
85 static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
86 std::int32_t k) {
87 return 128 * 1024;
88 }
89};
90
91} // namespace internal
92
93std::int32_t gemv_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
94 std::int32_t max_threads) {
95 return internal::ResolveMaxThreads(max_threads) *
96 internal::GemvQuantized8BitOperation::ScratchPerThread(m, n, k);
97}
98
99void multi_thread_gemv_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
100 std::uint8_t* scratch, const std::uint8_t* lhs,
101 const std::uint8_t* rhs, std::int32_t n,
102 std::int32_t k, std::int32_t lhs_offset,
103 std::int32_t rhs_offset, std::int32_t sum_offset,
104 std::int32_t multiplier, std::int32_t shift,
105 std::uint8_t* result) {
106 max_threads = internal::ResolveMaxThreads(max_threads);
107 internal::GemvQuantized8BitOperation operation(lhs_offset, rhs_offset,
108 sum_offset, multiplier, shift);
109 if (max_threads == 1) {
110 operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
111 } else {
112 internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
113 n, k, result, n, operation);
114 }
115}
116
117std::int32_t gemv_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
118 std::int32_t max_threads) {
119 return internal::ResolveMaxThreads(max_threads) *
120 internal::GemvFloatOperation::ScratchPerThread(m, n, k);
121}
122
123void multi_thread_gemv_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
124 std::uint8_t* scratch, const std::uint8_t* lhs,
125 const std::uint8_t* rhs, std::int32_t n,
126 std::int32_t k, std::int32_t lhs_offset,
127 std::int32_t rhs_offset, float result_offset,
128 float* result) {
129 max_threads = internal::ResolveMaxThreads(max_threads);
130 internal::GemvFloatOperation operation(lhs_offset, rhs_offset, result_offset);
131 if (max_threads == 1) {
132 operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
133 } else {
134 internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
135 n, k, result, n, operation);
136 }
137}
138
139std::int32_t gemv_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
140 std::int32_t max_threads) {
141 return internal::ResolveMaxThreads(max_threads) *
142 internal::GemvInt32Operation::ScratchPerThread(m, n, k);
143}
144
145void multi_thread_gemv_i32(gemmlowp::WorkersPool* pool,
146 std::int32_t max_threads, std::uint8_t* scratch,
147 const std::uint8_t* lhs, const std::uint8_t* rhs,
148 std::int32_t n, std::int32_t k,
149 std::int32_t lhs_offset, std::int32_t rhs_offset,
150 std::int32_t* result) {
151 max_threads = internal::ResolveMaxThreads(max_threads);
152 internal::GemvInt32Operation operation(lhs_offset, rhs_offset);
153 if (max_threads == 1) {
154 operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
155 } else {
156 internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
157 n, k, result, n, operation);
158 }
159}
160
161} // namespace meta
162} // namespace gemmlowp
163
164#else
165#warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!"
166#endif
167
168#endif // GEMMLOWP_META_MULTI_THREAD_GEMV_H_