| // Copyright 2016 The Gemmlowp Authors. All Rights Reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_ |
| #define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_ |
| |
| #include <iostream> |
| #include "base.h" |
| |
| namespace gemmlowp { |
| namespace meta { |
| |
// Entry point of the single threaded gemm. Declared ahead of its definition
// (at the end of this header) because the cache friendly executors below call
// it on the sub-tasks they create.
template <typename Executor, typename Params,
          int kernel_m, int kernel_n, int kernel_k>
void Gemm(const Params& params);
| |
| class GemmExecutorPackRHS { |
| public: |
| template <typename P> |
| static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n, |
| int kernel_k) { |
| const int lhs_scratch = |
| StreamUtil<typename P::InType, typename P::LeftStream>::Scratch( |
| params.left_stream, kernel_m, kernel_k); |
| const int rhs_chunks = ((params.n + kernel_n - 1) / kernel_n); |
| const int rhs_scratch = |
| rhs_chunks * |
| StreamUtil<typename P::InType, typename P::RightStream>::Scratch( |
| params.right_stream, kernel_n, kernel_k); |
| return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch); |
| } |
| |
| template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers, |
| int k_leftovers> |
| static void ExecuteDispatch3D(const P& params) { |
| // Shorthand typedefs for streams and multiply kernels. |
| typedef typename P::InType InType; |
| typedef typename P::OutType OutType; |
| |
| typedef Stream<typename P::InType, m, k, k_leftovers, |
| typename P::LeftStream> |
| LeftStreamF; |
| typedef Stream<typename P::InType, m_leftovers, k, k_leftovers, |
| typename P::LeftStream> |
| LeftStreamL; |
| |
| typedef Stream<typename P::InType, n, k, k_leftovers, |
| typename P::RightStream> |
| RightStreamF; |
| typedef Stream<typename P::InType, n_leftovers, k, k_leftovers, |
| typename P::RightStream> |
| RightStreamL; |
| |
| typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream> |
| OutputStreamFF; |
| typedef Stream<typename P::OutType, m_leftovers, n, 0, |
| typename P::OutputStream> |
| OutputStreamLF; |
| |
| typedef MulKernel<typename P::InType, typename P::OutType, |
| typename P::Kernel, typename P::OutputStream, m, n, k> |
| KernelFF; |
| typedef MulKernel<typename P::InType, typename P::OutType, |
| typename P::Kernel, typename P::OutputStream, m, |
| n_leftovers, k> |
| KernelFL; |
| typedef MulKernel<typename P::InType, typename P::OutType, |
| typename P::Kernel, typename P::OutputStream, m_leftovers, |
| n, k> |
| KernelLF; |
| typedef MulKernel<typename P::InType, typename P::OutType, |
| typename P::Kernel, typename P::OutputStream, m_leftovers, |
| n_leftovers, k> |
| KernelLL; |
| |
| #ifdef DEBUG |
| #ifdef DEBUG_METAGEMM_VERBOSE |
| std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n |
| << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x" |
| << k_leftovers << " -- " << params.m << "x" << params.n << "x" |
| << params.k << std::endl; |
| LeftStreamF::Debug(params.left_stream); |
| LeftStreamL::Debug(params.left_stream); |
| |
| RightStreamF::Debug(params.right_stream); |
| RightStreamL::Debug(params.right_stream); |
| |
| OutputStreamFF::Debug(params.fused_kernel.output_stream); |
| OutputStreamLF::Debug(params.fused_kernel.output_stream); |
| |
| KernelFF::Debug(params.fused_kernel); |
| KernelFL::Debug(params.fused_kernel); |
| KernelLF::Debug(params.fused_kernel); |
| KernelLL::Debug(params.fused_kernel); |
| #endif |
| #endif |
| |
| int lhs_chunks = params.m / m; |
| int rhs_chunks = params.n / n; |
| |
| // Scratch memory for packed LHS & RHS chunks. |
| |
| std::uint8_t* packed_lhs = params.scratch; |
| std::uint8_t* packed_rhs = |
| params.scratch + LeftStreamF::Scratch(params.left_stream); |
| |
| // Pack full RHS first. |
| |
| std::uint8_t* packed_rhs_chunk = packed_rhs; |
| const int packed_rhs_chunk_size = |
| RightStreamF::PackedStride(params.right_stream); |
| |
| { |
| const std::uint8_t* rhs_chunk = |
| reinterpret_cast<const std::uint8_t*>(params.rhs); |
| const int rhs_chunk_size = |
| RightStreamF::UnpackedStride(params.right_stream); |
| |
| for (int i = 0; i < rhs_chunks; ++i) { |
| RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk), |
| params.right_stream, |
| reinterpret_cast<InType*>(packed_rhs_chunk)); |
| |
| rhs_chunk += rhs_chunk_size; |
| packed_rhs_chunk += packed_rhs_chunk_size; |
| } |
| |
| RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk), |
| params.right_stream, |
| reinterpret_cast<InType*>(packed_rhs_chunk)); |
| } |
| |
| // Multiply RHS by LHS one LHS chunk at a time. |
| |
| const std::uint8_t* lhs_chunk = |
| reinterpret_cast<const std::uint8_t*>(params.lhs); |
| std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result); |
| std::uint8_t* result_chunk = result_strip; |
| |
| { |
| const int lhs_chunk_size = |
| LeftStreamF::UnpackedStride(params.left_stream); |
| const int result_strip_size = |
| OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream); |
| const int result_chunk_size = |
| OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream); |
| |
| for (int i = 0; i < lhs_chunks; ++i) { |
| LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk), |
| params.left_stream, |
| reinterpret_cast<InType*>(packed_lhs)); |
| |
| result_chunk = result_strip; |
| packed_rhs_chunk = packed_rhs; |
| |
| for (int j = 0; j < rhs_chunks; ++j) { |
| KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs), |
| reinterpret_cast<const InType*>(packed_rhs_chunk), |
| params.fused_kernel, |
| reinterpret_cast<OutType*>(result_chunk)); |
| |
| result_chunk += result_chunk_size; |
| packed_rhs_chunk += packed_rhs_chunk_size; |
| } |
| |
| KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs), |
| reinterpret_cast<const InType*>(packed_rhs_chunk), |
| params.fused_kernel, |
| reinterpret_cast<OutType*>(result_chunk)); |
| |
| lhs_chunk += lhs_chunk_size; |
| result_strip += result_strip_size; |
| } |
| } |
| |
| // Leftover LHS chunk. |
| if (m_leftovers > 0) { // static if |
| const int result_chunk_size = |
| OutputStreamLF::UnpackedAdvance(params.fused_kernel.output_stream); |
| |
| LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk), |
| params.left_stream, |
| reinterpret_cast<InType*>(packed_lhs)); |
| |
| result_chunk = result_strip; |
| packed_rhs_chunk = packed_rhs; |
| |
| for (int i = 0; i < rhs_chunks; ++i) { |
| KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs), |
| reinterpret_cast<const InType*>(packed_rhs_chunk), |
| params.fused_kernel, |
| reinterpret_cast<OutType*>(result_chunk)); |
| |
| result_chunk += result_chunk_size; |
| packed_rhs_chunk += packed_rhs_chunk_size; |
| } |
| |
| KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs), |
| reinterpret_cast<const InType*>(packed_rhs_chunk), |
| params.fused_kernel, |
| reinterpret_cast<OutType*>(result_chunk)); |
| } |
| } |
| }; |
| |
| class GemmExecutorPackLHS { |
| public: |
| template <typename P> |
| static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n, |
| int kernel_k) { |
| const int lhs_chunks = ((params.m + kernel_m - 1) / kernel_m); |
| const int lhs_scratch = |
| lhs_chunks * |
| StreamUtil<typename P::InType, typename P::LeftStream>::Scratch( |
| params.left_stream, kernel_m, kernel_k); |
| const int rhs_scratch = |
| StreamUtil<typename P::InType, typename P::RightStream>::Scratch( |
| params.right_stream, kernel_n, kernel_k); |
| return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch); |
| } |
| |
| template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers, |
| int k_leftovers> |
| static void ExecuteDispatch3D(const P& params) { |
| // Shorthand typedefs for streams and multiply kernels. |
| typedef typename P::InType InType; |
| typedef typename P::OutType OutType; |
| |
| typedef Stream<typename P::InType, m, k, k_leftovers, |
| typename P::LeftStream> |
| LeftStreamF; |
| typedef Stream<typename P::InType, m_leftovers, k, k_leftovers, |
| typename P::LeftStream> |
| LeftStreamL; |
| |
| typedef Stream<typename P::InType, n, k, k_leftovers, |
| typename P::RightStream> |
| RightStreamF; |
| typedef Stream<typename P::InType, n_leftovers, k, k_leftovers, |
| typename P::RightStream> |
| RightStreamL; |
| |
| typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream> |
| OutputStreamFF; |
| typedef Stream<typename P::OutType, m, n_leftovers, 0, |
| typename P::OutputStream> |
| OutputStreamFL; |
| |
| typedef MulKernel<typename P::InType, typename P::OutType, |
| typename P::Kernel, typename P::OutputStream, m, n, k> |
| KernelFF; |
| typedef MulKernel<typename P::InType, typename P::OutType, |
| typename P::Kernel, typename P::OutputStream, m, |
| n_leftovers, k> |
| KernelFL; |
| typedef MulKernel<typename P::InType, typename P::OutType, |
| typename P::Kernel, typename P::OutputStream, m_leftovers, |
| n, k> |
| KernelLF; |
| typedef MulKernel<typename P::InType, typename P::OutType, |
| typename P::Kernel, typename P::OutputStream, m_leftovers, |
| n_leftovers, k> |
| KernelLL; |
| #ifdef DEBUG |
| #ifdef DEBUG_METAGEMM_VERBOSE |
| std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n |
| << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x" |
| << k_leftovers << " -- " << params.m << "x" << params.n << "x" |
| << params.k << std::endl; |
| LeftStreamF::Debug(params.left_stream); |
| LeftStreamL::Debug(params.left_stream); |
| |
| RightStreamF::Debug(params.right_stream); |
| RightStreamL::Debug(params.right_stream); |
| |
| OutputStreamFF::Debug(params.fused_kernel.output_stream); |
| OutputStreamFL::Debug(params.fused_kernel.output_stream); |
| |
| KernelFF::Debug(params.fused_kernel); |
| KernelFL::Debug(params.fused_kernel); |
| KernelLF::Debug(params.fused_kernel); |
| KernelLL::Debug(params.fused_kernel); |
| #endif |
| #endif |
| |
| int lhs_chunks = params.m / m; |
| int rhs_chunks = params.n / n; |
| |
| // Scratch memory for packed LHS & RHS chunks. |
| std::uint8_t* packed_rhs = params.scratch; |
| std::uint8_t* packed_lhs = |
| params.scratch + RightStreamF::Scratch(params.right_stream); |
| |
| // Pack full LHS first. |
| |
| std::uint8_t* packed_lhs_chunk = packed_lhs; |
| const int packed_lhs_chunk_size = |
| LeftStreamF::PackedStride(params.left_stream); |
| |
| { |
| const std::uint8_t* lhs_chunk = |
| reinterpret_cast<const std::uint8_t*>(params.lhs); |
| const int lhs_chunk_size = |
| LeftStreamF::UnpackedStride(params.left_stream); |
| |
| for (int i = 0; i < lhs_chunks; ++i) { |
| LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk), |
| params.left_stream, |
| reinterpret_cast<InType*>(packed_lhs_chunk)); |
| |
| lhs_chunk += lhs_chunk_size; |
| packed_lhs_chunk += packed_lhs_chunk_size; |
| } |
| |
| LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk), |
| params.left_stream, |
| reinterpret_cast<InType*>(packed_lhs_chunk)); |
| } |
| |
| // Multiply RHS by LHS one RHS chunk at a time. |
| |
| const std::uint8_t* rhs_chunk = |
| reinterpret_cast<const std::uint8_t*>(params.rhs); |
| std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result); |
| std::uint8_t* result_chunk = result_strip; |
| |
| { |
| const int rhs_chunk_size = |
| RightStreamF::UnpackedStride(params.right_stream); |
| const int result_strip_size = |
| OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream); |
| const int result_chunk_size = |
| OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream); |
| |
| for (int i = 0; i < rhs_chunks; ++i) { |
| RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk), |
| params.right_stream, |
| reinterpret_cast<InType*>(packed_rhs)); |
| |
| result_chunk = result_strip; |
| packed_lhs_chunk = packed_lhs; |
| |
| for (int j = 0; j < lhs_chunks; ++j) { |
| KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk), |
| reinterpret_cast<const InType*>(packed_rhs), |
| params.fused_kernel, |
| reinterpret_cast<OutType*>(result_chunk)); |
| |
| result_chunk += result_chunk_size; |
| packed_lhs_chunk += packed_lhs_chunk_size; |
| } |
| |
| KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk), |
| reinterpret_cast<const InType*>(packed_rhs), |
| params.fused_kernel, |
| reinterpret_cast<OutType*>(result_chunk)); |
| |
| rhs_chunk += rhs_chunk_size; |
| result_strip += result_strip_size; |
| } |
| } |
| |
| // Leftover RHS chunk. |
| if (n_leftovers > 0) { // static if |
| const int result_chunk_size = |
| OutputStreamFL::UnpackedStride(params.fused_kernel.output_stream); |
| |
| RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk), |
| params.right_stream, |
| reinterpret_cast<InType*>(packed_rhs)); |
| |
| result_chunk = result_strip; |
| packed_lhs_chunk = packed_lhs; |
| |
| for (int i = 0; i < lhs_chunks; ++i) { |
| KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk), |
| reinterpret_cast<const InType*>(packed_rhs), |
| params.fused_kernel, |
| reinterpret_cast<OutType*>(result_chunk)); |
| |
| result_chunk += result_chunk_size; |
| packed_lhs_chunk += packed_lhs_chunk_size; |
| } |
| |
| KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk), |
| reinterpret_cast<const InType*>(packed_rhs), |
| params.fused_kernel, |
| reinterpret_cast<OutType*>(result_chunk)); |
| } |
| } |
| }; |
| |
| namespace internal { |
| |
// Computes into how many tasks a dimension of size total_dim has to be split
// so that each task's chunks fit into the cache.
//
//   cache_size       - total cache budget.
//   constant_memory  - memory needed regardless of the number of chunks.
//   per_chunk_memory - additional memory needed by every chunk.
//   total_dim        - size of the dimension being split.
//   chunk_dim        - size of one chunk along that dimension (must be > 0).
//
// Always returns at least 1, so callers can safely divide by the result. In
// particular an empty dimension (total_dim == 0) yields a single task instead
// of 0 tasks.
inline int CalculateCacheFriendlyTasksCount(int cache_size, int constant_memory,
                                            int per_chunk_memory, int total_dim,
                                            int chunk_dim) {
  assert(constant_memory + per_chunk_memory < cache_size);

  const int chunks_count = (total_dim + chunk_dim - 1) / chunk_dim;
  // Nothing to split (empty or single chunk dimension), or chunks are free:
  // one task handles everything. This also keeps the division below away
  // from a zero per_chunk_memory.
  if (chunks_count <= 1 || per_chunk_memory <= 0) {
    return 1;
  }

  const int available_cache = cache_size - constant_memory;
  const int available_chunks = available_cache / per_chunk_memory;
  // Only reachable when the precondition above is violated in a build with
  // asserts compiled out: fall back to one chunk per task rather than
  // dividing by zero.
  if (available_chunks < 1) {
    return chunks_count;
  }

  return (chunks_count + available_chunks - 1) / available_chunks;
}
| |
| template <typename Params> |
| inline void UpdateCacheFriendlyTask(int m_offset, int m, int n_offset, int n, |
| const Params& params, Params* task_params) { |
| task_params->m = m; |
| task_params->lhs = |
| StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset( |
| params.left_stream, params.lhs, m_offset, 0); |
| |
| task_params->n = n; |
| task_params->rhs = |
| StreamUtil<typename Params::InType, typename Params::RightStream>::Offset( |
| params.right_stream, params.rhs, n_offset, 0); |
| |
| task_params->result = |
| StreamUtil<typename Params::OutType, typename Params::OutputStream>:: |
| Offset(params.fused_kernel.output_stream, params.result, m_offset, |
| n_offset); |
| } |
| |
| } // namespace internal |
| |
| template <int cache_size = 256 * 1024> |
| class GemmExecutorPackRHSCacheFriendly { |
| public: |
| template <typename P> |
| static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n, |
| int kernel_k) { |
| return cache_size; |
| } |
| |
| template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers, |
| int k_leftovers> |
| static void ExecuteDispatch3D(const P& params) { |
| typedef Stream<typename P::InType, m, k, k_leftovers, |
| typename P::LeftStream> |
| LeftStream; |
| |
| typedef Stream<typename P::InType, n, k, k_leftovers, |
| typename P::RightStream> |
| RightStream; |
| |
| const int lhs_scratch = LeftStream::Scratch(params.left_stream); |
| const int rhs_scratch = RightStream::Scratch(params.right_stream); |
| |
| const int cache_friendly_tasks_count = |
| internal::CalculateCacheFriendlyTasksCount(cache_size, lhs_scratch, |
| rhs_scratch, params.n, n); |
| |
| if (cache_friendly_tasks_count == 1) { |
| GemmExecutorPackRHS::ExecuteDispatch3D<P, m, n, k, m_leftovers, |
| n_leftovers, k_leftovers>(params); |
| return; |
| } |
| |
| const int cache_friendly_dim = params.n / cache_friendly_tasks_count; |
| |
| P task_params = params; |
| for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) { |
| internal::UpdateCacheFriendlyTask(0, params.m, i * cache_friendly_dim, |
| cache_friendly_dim, params, |
| &task_params); |
| Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params); |
| } |
| const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim; |
| internal::UpdateCacheFriendlyTask(0, params.m, dim_sum, params.n - dim_sum, |
| params, &task_params); |
| Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params); |
| } |
| }; |
| |
| template <int cache_size = 256 * 1024> |
| class GemmExecutorPackLHSCacheFriendly { |
| public: |
| template <typename P> |
| static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n, |
| int kernel_k) { |
| return cache_size; |
| } |
| |
| template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers, |
| int k_leftovers> |
| static void ExecuteDispatch3D(const P& params) { |
| typedef Stream<typename P::InType, m, k, k_leftovers, |
| typename P::LeftStream> |
| LeftStream; |
| |
| typedef Stream<typename P::InType, n, k, k_leftovers, |
| typename P::RightStream> |
| RightStream; |
| |
| const int lhs_scratch = LeftStream::Scratch(params.left_stream); |
| const int rhs_scratch = RightStream::Scratch(params.right_stream); |
| |
| const int cache_friendly_tasks_count = |
| internal::CalculateCacheFriendlyTasksCount(cache_size, rhs_scratch, |
| lhs_scratch, params.m, m); |
| |
| if (cache_friendly_tasks_count == 1) { |
| GemmExecutorPackLHS::ExecuteDispatch3D<P, m, n, k, m_leftovers, |
| n_leftovers, k_leftovers>(params); |
| return; |
| } |
| |
| const int cache_friendly_dim = params.m / cache_friendly_tasks_count; |
| |
| P task_params = params; |
| for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) { |
| internal::UpdateCacheFriendlyTask(i * cache_friendly_dim, |
| cache_friendly_dim, 0, params.n, params, |
| &task_params); |
| Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params); |
| } |
| const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim; |
| internal::UpdateCacheFriendlyTask(dim_sum, params.m - dim_sum, 0, params.n, |
| params, &task_params); |
| Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params); |
| } |
| }; |
| |
| namespace internal { |
| |
| // Stage 3. |
| |
// Last dispatch stage: turns the runtime k leftover into a template argument.
// Candidates are tried from variable_k downwards; on a match the executor E is
// invoked with the fully resolved set of template arguments.
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n, int variable_k>
struct Dispatch3DStage3 {
  static void Execute(const P& params, int k) {
#if defined(DEBUG) && defined(DEBUG_METAGEMM_VERBOSE)
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << variable_k
              << std::endl
              << std::flush;
#endif
    if (k != variable_k) {
      // Not this candidate: try the next smaller one.
      typedef Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                               variable_k - 1>
          NextCandidate;
      NextCandidate::Execute(params, k);
      return;
    }
    E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                                  variable_k>(params);
  }
};

// End of the stage 3 recursion: 0 is the last candidate, anything else is a
// fatal error.
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n>
struct Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0> {
  static void Execute(const P& params, int k) {
#if defined(DEBUG) && defined(DEBUG_METAGEMM_VERBOSE)
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << 0 << std::endl
              << std::flush;
#endif
    if (k != 0) {
      std::cerr << "FATAL: dispatch3DStage3 failed: ran out of cases."
                << std::endl
                << std::flush;
      std::exit(1);
    }
    E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0>(
        params);
  }
};
| |
| // Stage 2. |
| |
| template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m, |
| int variable_n> |
| struct Dispatch3DStage2 { |
| static void Execute(const P& params, int n, int k) { |
| #ifdef DEBUG |
| #ifdef DEBUG_METAGEMM_VERBOSE |
| std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k |
| << " : " << fixed_m << "x" << variable_n << std::endl |
| << std::flush; |
| #endif |
| #endif |
| if (n == variable_n) { |
| Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, variable_n, |
| dim_k - 1>::Execute(params, k); |
| } else { |
| Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, |
| variable_n - 1>::Execute(params, n, k); |
| } |
| } |
| }; |
| |
| template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m> |
| struct Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, 0> { |
| static void Execute(const P& params, int n, int k) { |
| #ifdef DEBUG |
| #ifdef DEBUG_METAGEMM_VERBOSE |
| std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k |
| << " : " << fixed_m << "x" << 0 << std::endl |
| << std::flush; |
| #endif |
| #endif |
| if (n == 0) { |
| Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, 0, |
| dim_k - 1>::Execute(params, k); |
| } else { |
| std::cerr << "FATAL: dispatch3DStage2 failed: ran out of cases." |
| << std::endl |
| << std::flush; |
| std::exit(1); |
| } |
| } |
| }; |
| |
| // Stage 1. |
| |
| template <typename E, typename P, int dim_m, int dim_n, int dim_k, |
| int variable_m> |
| struct Dispatch3DStage1 { |
| static void Execute(const P& params, int m, int n, int k) { |
| #ifdef DEBUG |
| #ifdef DEBUG_METAGEMM_VERBOSE |
| std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k |
| << " : " << variable_m << std::endl |
| << std::flush; |
| #endif |
| #endif |
| if (m == variable_m) { |
| Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, variable_m, |
| dim_n - 1>::Execute(params, n, k); |
| } else { |
| Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, variable_m - 1>::Execute( |
| params, m, n, k); |
| } |
| } |
| }; |
| |
| template <typename E, typename P, int dim_m, int dim_n, int dim_k> |
| struct Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, 0> { |
| static void Execute(const P& params, int m, int n, int k) { |
| #ifdef DEBUG |
| #ifdef DEBUG_METAGEMM_VERBOSE |
| std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k |
| << " : " << 0 << std::endl |
| << std::flush; |
| #endif |
| #endif |
| if (m == 0) { |
| Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, 0, dim_n - 1>::Execute(params, |
| n, k); |
| } else { |
| std::cerr << "FATAL: dispatch3DStage1 failed: ran out of cases." |
| << std::endl |
| << std::flush; |
| std::exit(1); |
| } |
| } |
| }; |
| |
| } // namespace internal |
| |
| template <typename Executor, typename Params, int kernel_m, int kernel_n, |
| int kernel_k> |
| inline void Gemm(const Params& params) { |
| internal::Dispatch3DStage1<Executor, Params, kernel_m, kernel_n, kernel_k, |
| kernel_m - 1>::Execute(params, params.m % kernel_m, |
| params.n % kernel_n, |
| params.k % kernel_k); |
| } |
| |
| } // namespace meta |
| } // namespace gemmlowp |
| |
| #endif // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_ |