Rebase gemmlowp to 36ffd29

Test: mm
Test: build system image for sailfish
Test: BLAS CTS tests pass

Change-Id: I4cc9dbfd586f6653fc2d04e8e7ad78ada5d7dbe9
diff --git a/internal/allocator.h b/internal/allocator.h
index b0d7781..0fe4a01 100644
--- a/internal/allocator.h
+++ b/internal/allocator.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -43,10 +43,19 @@
 
 #if defined ANDROID || defined __ANDROID__
 #include <android/api-level.h>
-#if __ANDROID_API__ < 16
+// The 18 here should be 16, but has to be 18 for now due
+// to a Google-internal issue.
+#if __ANDROID_API__ < 18
 #include <malloc.h>
 #define GEMMLOWP_USE_MEMALIGN
 #endif
+// posix_memalign is missing on some 4.1 x86 devices
+#if __ANDROID_API__ == 18
+#ifdef GEMMLOWP_X86_32
+#include <malloc.h>
+#define GEMMLOWP_USE_MEMALIGN
+#endif
+#endif
 #endif
 
 namespace gemmlowp {
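The hunk above extends GEMMLOWP_USE_MEMALIGN to API level 18 and to the x86 devices where posix_memalign is missing. As a minimal sketch (not the allocator's actual code; AlignedAlloc is a hypothetical name used only for illustration), the macro toggles between these two standard aligned-allocation calls:

#include <cstdlib>   // posix_memalign, std::size_t
#ifdef GEMMLOWP_USE_MEMALIGN
#include <malloc.h>  // memalign
#endif

inline void* AlignedAlloc(std::size_t alignment, std::size_t size) {
#ifdef GEMMLOWP_USE_MEMALIGN
  // Older Android API levels, and the API-18 x86 case above, lack
  // posix_memalign, so fall back to memalign from <malloc.h>.
  return memalign(alignment, size);
#else
  // posix_memalign returns 0 on success; alignment must be a power of two
  // and a multiple of sizeof(void*).
  void* ptr = nullptr;
  if (posix_memalign(&ptr, alignment, size) != 0) {
    ptr = nullptr;
  }
  return ptr;
#endif
}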
diff --git a/internal/block_params.h b/internal/block_params.h
index 48f93df..b2fc3ff 100644
--- a/internal/block_params.h
+++ b/internal/block_params.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -43,15 +43,19 @@
   int l2_depth;
 
   template <typename KernelFormat>
-  void Init(int rows, int cols, int depth, int num_threads) {
-    FindL2BlockSizes<KernelFormat>(rows, cols, depth, num_threads, &l2_rows,
-                                   &l2_cols, &l2_depth);
-    FindL1BlockSizes<KernelFormat>(l2_rows, l2_cols, l2_depth, &l1_rows,
-                                   &l1_cols, &l1_depth);
+  void Init(int rows, int cols, int depth, int num_threads,
+            int l1_bytes_to_use, int l2_bytes_to_use, float l2_rhs_factor) {
+    FindL2BlockSizes<KernelFormat>(rows, cols, depth, num_threads,
+                                   l2_bytes_to_use, l2_rhs_factor,
+                                   &l2_rows, &l2_cols, &l2_depth);
+    FindL1BlockSizes<KernelFormat>(l2_rows, l2_cols, l2_depth,
+                                   l1_bytes_to_use,
+                                   &l1_rows, &l1_cols, &l1_depth);
   }
 
   template <typename KernelFormat>
   static void FindL2BlockSizes(int rows, int cols, int depth, int num_threads,
+                               int l2_bytes_to_use, float l2_rhs_factor,
                                int* out_l2_rows, int* out_l2_cols,
                                int* out_l2_depth) {
     int l2_rows = 0;
@@ -64,9 +68,6 @@
     // of register size, so as to avoid having to special-case unaligned depths.
     l2_depth = RoundUp<kRegisterSize>(depth);
 
-    const int l2_bytes_to_use = kDefaultL2CacheSize;
-    const float l2_rhs_factor = kDefaultL2RhsFactor;
-
     {
       int max_cache_friendly_l2_cols = std::max(
           1, static_cast<int>(l2_rhs_factor * (l2_bytes_to_use / l2_depth)));
@@ -97,7 +98,8 @@
   }
 
   template <typename KernelFormat>
-  static void FindL1BlockSizes(int rows, int cols, int depth, int* out_l1_rows,
+  static void FindL1BlockSizes(int rows, int cols, int depth,
+                               int l1_bytes_to_use, int* out_l1_rows,
                                int* out_l1_cols, int* out_l1_depth) {
     int l1_rows = 0;
     int l1_cols = 0;
@@ -112,8 +114,6 @@
     // Thought not to be needed. Similar to Eigen.
     l1_cols = cols;
 
-    const int l1_bytes_to_use = kDefaultL1CacheSize;
-
     {
       int max_cache_friendly_l1_depth = std::max(
           1, (l1_bytes_to_use - 4 * KernelFormat::kRows * KernelFormat::kCols) /
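With the cache-tuning parameters now passed into Init() instead of being read from kDefaultL1CacheSize / kDefaultL2CacheSize / kDefaultL2RhsFactor inside FindL1BlockSizes / FindL2BlockSizes, a caller that wants the previous behavior forwards the defaults explicitly. A minimal usage sketch, assuming the kDefault* constants from internal/common.h are still the values the caller wants (ExampleInitBlockParams is a hypothetical helper, not part of the patch):

template <typename KernelFormat>
void ExampleInitBlockParams(int rows, int cols, int depth, int num_threads,
                            gemmlowp::BlockParams* block_params) {
  block_params->Init<KernelFormat>(
      rows, cols, depth, num_threads,
      gemmlowp::kDefaultL1CacheSize,   // l1_bytes_to_use
      gemmlowp::kDefaultL2CacheSize,   // l2_bytes_to_use
      gemmlowp::kDefaultL2RhsFactor);  // l2_rhs_factor
}

Passing the sizes through lets a context or a benchmark override them per call rather than recompiling with different constants.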
diff --git a/internal/common.h b/internal/common.h
index 3d94041..1d89b26 100644
--- a/internal/common.h
+++ b/internal/common.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -86,20 +86,32 @@
 #define GEMMLOWP_NEON_64
 #endif
 
-// Detect SSE4.
-#if defined __SSE4_1__
+// Detect SSE.
+#ifdef __SSE4_1__
 #define GEMMLOWP_SSE4
 #endif
 
+#ifdef __SSE3__
+#define GEMMLOWP_SSE3
+#endif
+
 // Convenience SSE4 tokens for 32-bit or 64-bit
 #if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32)
 #define GEMMLOWP_SSE4_32
 #endif
 
+#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32)
+#define GEMMLOWP_SSE3_32
+#endif
+
 #if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64)
 #define GEMMLOWP_SSE4_64
 #endif
 
+#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64)
+#define GEMMLOWP_SSE3_64
+#endif
+
 #endif  // GEMMLOWP_ALLOW_INLINE_ASM
 
 // Detect Android. Don't conflate with ARM - we care about tuning
@@ -134,8 +146,13 @@
 // Of course, these values are in principle too low for typical x86 CPUs
 // where we should set the L2 value to (L3 cache size / number of cores) at
 // least.
-#if defined(GEMMLOWP_ARM) || defined(GEMMLOWP_ANDROID)
-// ARM or ARM-like hardware (Android implies ARM-like) so here it's OK
+//
+#if defined(GEMMLOWP_ARM) && defined(__APPLE__)
+// iPhone/iPad
+const int kDefaultL1CacheSize = 48 * 1024;
+const int kDefaultL2CacheSize = 2 * 1024 * 1024;
+#elif defined(GEMMLOWP_ARM) || defined(GEMMLOWP_ANDROID)
+// Other ARM or ARM-like hardware (Android implies ARM-like) so here it's OK
 // to tune for ARM, although on x86 Atom we might be able to query
 // cache sizes at runtime, which would be better.
 const int kDefaultL1CacheSize = 16 * 1024;
@@ -180,13 +197,17 @@
 // are consistent with this value.
 const int kRegisterSize = 16;
 
-// Requantization to less-than-8-bit is costly, so it only worth
-// doing if the GEMM width is large enough
-const int kMinimumWidthForRequantization = 100;
-
 // Hints the CPU to prefetch the cache line containing ptr.
 inline void Prefetch(const void* ptr) {
-#ifdef __GNUC__  // Clang and GCC define __GNUC__ and have __builtin_prefetch.
+#if defined GEMMLOWP_ARM_64 && defined GEMMLOWP_ALLOW_INLINE_ASM
+  // AArch64 has very detailed prefetch instructions that compilers
+  // cannot know how to map __builtin_prefetch onto, so in practice they
+  // don't, leaving __builtin_prefetch a no-op on this architecture.
+  // For our purposes, "pldl1keep" is usually what we want, meaning:
+  // "prefetch for load, into L1 cache, using each value multiple times".
+  asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) : );
+#elif defined \
+    __GNUC__  // Clang and GCC define __GNUC__ and have __builtin_prefetch.
   __builtin_prefetch(ptr);
 #else
   (void)ptr;
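The new AArch64 branch above issues an explicit "prfm pldl1keep" because, per the comment, __builtin_prefetch ends up a no-op there. On the portable branch, roughly the same intent (prefetch for a read, keep the line resident) can be spelled with __builtin_prefetch's optional rw/locality arguments; a hedged sketch with a hypothetical helper name:

inline void PrefetchForReadHighLocality(const void* ptr) {
#ifdef __GNUC__
  // rw = 0 requests a read prefetch; locality = 3 asks for the line to be
  // kept in as close a cache level as possible.
  __builtin_prefetch(ptr, 0, 3);
#else
  (void)ptr;
#endif
}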
diff --git a/internal/compute.h b/internal/compute.h
index 4587df3..bbc9e2a 100644
--- a/internal/compute.h
+++ b/internal/compute.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -48,9 +48,11 @@
         packed_lhs_(_packed_lhs),
         packed_rhs_(_packed_rhs) {}
 
-  void Compute() {
-    for (int d = 0; d < block_params_.l2_depth; d += block_params_.l1_depth) {
-      int ds = std::min(block_params_.l1_depth, block_params_.l2_depth - d);
+  void Compute(int depth) {
+    depth = RoundUp<Format::kDepth>(depth);
+    assert(depth <= block_params_.l2_depth);
+    for (int d = 0; d < depth; d += block_params_.l1_depth) {
+      int ds = std::min(block_params_.l1_depth, depth - d);
 
       for (int r = 0; r < block_params_.l2_rows; r += block_params_.l1_rows) {
         int rs = std::min(block_params_.l1_rows, block_params_.l2_rows - r);
@@ -89,12 +91,12 @@
 template <typename PackedLhs, typename PackedRhs, typename PackedResult>
 void Compute(const KernelBase& kernel, const BlockParams& block_params,
              PackedResult* packed_result, const PackedLhs& packed_lhs,
-             const PackedRhs& packed_rhs) {
+             const PackedRhs& packed_rhs, int depth) {
   ScopedProfilingLabel label("compute");
   ComputeImpl<PackedLhs, PackedRhs, PackedResult> impl(
       kernel, block_params, packed_result, packed_lhs, packed_rhs);
 
-  impl.Compute();
+  impl.Compute(depth);
 }
 
 }  // namespace gemmlowp
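Compute() now receives the actual depth and rounds it up to the kernel's depth granularity before looping, asserting that the padded depth still fits in the l2 block. A small self-contained check of that rounding (RoundUpExample is a stand-in for gemmlowp's RoundUp<Modulus>() from internal/common.h; the only assumption is that RoundUp rounds up to the next multiple):

template <int Modulus>
constexpr int RoundUpExample(int x) {
  // Round x up to the next multiple of Modulus (x >= 0 assumed).
  return ((x + Modulus - 1) / Modulus) * Modulus;
}

static_assert(RoundUpExample<2>(7) == 8, "ragged depth padded to kernel depth");
static_assert(RoundUpExample<2>(8) == 8, "aligned depth left unchanged");
static_assert(RoundUpExample<16>(7) == 16, "register-size rounding as used for l2_depth");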
diff --git a/internal/dispatch_gemm_shape.h b/internal/dispatch_gemm_shape.h
new file mode 100644
index 0000000..0be0bf3
--- /dev/null
+++ b/internal/dispatch_gemm_shape.h
@@ -0,0 +1,189 @@
+// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// dispatch_gemm_shape.h: dispatch GEMM calls according to their shape
+
+#ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
+#define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
+
+#include "../internal/kernel_default.h"
+#include "../public/map.h"
+#include "../public/output_stages.h"
+#include "multi_thread_gemm.h"
+
+namespace gemmlowp {
+
+template <typename T>
+struct TransposeImpl {
+  typedef T DstType;
+  static T Run(const T& t) { return t; }
+};
+
+template <typename T>
+using TransposeType = typename TransposeImpl<T>::DstType;
+
+template <typename T>
+TransposeType<T> Transpose(const T& t) {
+  return TransposeImpl<T>::Run(t);
+}
+
+template <MapOrder Order>
+struct TransposeMapOrder {
+  static constexpr MapOrder Value =
+      Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor;
+};
+
+template <VectorShape Shape>
+struct TransposeVectorShape {
+  static constexpr VectorShape Value =
+      Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row;
+};
+
+template <typename Scalar, VectorShape Shape>
+struct TransposeImpl<VectorMap<Scalar, Shape>> {
+  typedef VectorMap<Scalar, Shape> SrcType;
+  static constexpr VectorShape TransposedShape =
+      TransposeVectorShape<Shape>::Value;
+  typedef VectorMap<Scalar, TransposedShape> DstType;
+  static DstType Run(const SrcType& src) {
+    return DstType(src.data(), src.size());
+  }
+};
+
+template <typename Scalar, MapOrder Order>
+struct TransposeImpl<MatrixMap<Scalar, Order>> {
+  typedef MatrixMap<Scalar, Order> SrcType;
+  static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
+  typedef MatrixMap<Scalar, TransposedOrder> DstType;
+  static DstType Run(const SrcType& src) {
+    return DstType(src.data(), src.cols(), src.rows(), src.stride());
+  }
+};
+
+template <VectorShape Shape>
+struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
+  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
+  static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
+  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
+  static DstType Run(const SrcType& src) {
+    DstType dst;
+    dst.result_shift = src.result_shift;
+    dst.result_offset = Transpose(src.result_offset);
+    dst.result_mult_int = Transpose(src.result_mult_int);
+    return dst;
+  }
+};
+
+template <typename VectorMapType>
+struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
+  typedef OutputStageBiasAddition<VectorMapType> SrcType;
+  typedef TransposeType<VectorMapType> TransposedVectorMapType;
+  typedef OutputStageBiasAddition<TransposedVectorMapType> DstType;
+  static DstType Run(const SrcType& src) {
+    DstType dst;
+    dst.bias_vector = Transpose(src.bias_vector);
+    return dst;
+  }
+};
+
+// TODO(benoitjacob) - does anyone understand C++ variadic templates?
+// How to use them to implement TransposeTuple? Note: there are lots
+// of answers on StackOverflow, but they all seem to involve either
+// C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
+inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }
+
+template <typename T0>
+std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
+  return std::make_tuple(Transpose(std::get<0>(t)));
+}
+
+template <typename T0, typename T1>
+std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
+    const std::tuple<T0, T1>& t) {
+  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
+}
+
+template <typename T0, typename T1, typename T2>
+std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
+TransposeTuple(const std::tuple<T0, T1, T2>& t) {
+  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
+                         Transpose(std::get<2>(t)));
+}
+
+template <typename T0, typename T1, typename T2, typename T3>
+std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
+           TransposeType<T3>>
+TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
+  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
+                         Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4>
+std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
+           TransposeType<T3>, TransposeType<T4>>
+TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
+  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
+                         Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
+                         Transpose(std::get<4>(t)));
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+          typename T5>
+std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
+           TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
+TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
+  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
+                         Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
+                         Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
+}
+
+template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
+          MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
+          typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
+          typename GemmContextType>
+void DispatchGemmShape(GemmContextType* context,
+                       const MatrixMap<const InputScalar, LhsOrder>& lhs,
+                       const MatrixMap<const InputScalar, RhsOrder>& rhs,
+                       MatrixMap<OutputScalar, ResultOrder>* result,
+                       const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
+                       const OutputPipelineType& output_pipeline) {
+  assert(lhs.cols() == rhs.rows());
+
+  int rows = result->rows();
+  int cols = result->cols();
+  int depth = lhs.cols();
+
+  if (rows == 0 || cols == 0 || depth == 0) {
+    // Vacuous GEMM, return early to avoid having to deal with
+    // zero sizes below.
+    return;
+  }
+
+  if (rows < cols) {
+    auto transposed_result_map = Transpose(*result);
+    return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
+        context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
+        Transpose(rhs_offset), Transpose(lhs_offset),
+        TransposeTuple(output_pipeline));
+  }
+
+  typedef DefaultKernel<BitDepthParams> Kernel;
+  MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
+                  BitDepthParams>(context, Kernel(), lhs, rhs, result,
+                                  lhs_offset, rhs_offset, output_pipeline);
+}
+
+}  // end namespace gemmlowp
+
+#endif  // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
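The six hand-written TransposeTuple overloads above exist because of the C++11 constraint called out in the TODO. For reference, a hedged C++11-only sketch of the variadic version using a hand-rolled index sequence; IndexSeq, MakeIndexSeq and TransposeTupleImpl are made-up names, and this code is not part of the patch:

#include <cstddef>
#include <tuple>

namespace gemmlowp_sketch {

template <std::size_t... Is>
struct IndexSeq {};

// Builds IndexSeq<0, 1, ..., N-1> by recursion, since std::index_sequence
// is C++14-only.
template <std::size_t N, std::size_t... Is>
struct MakeIndexSeq : MakeIndexSeq<N - 1, N - 1, Is...> {};

template <std::size_t... Is>
struct MakeIndexSeq<0, Is...> {
  typedef IndexSeq<Is...> type;
};

template <typename... Ts, std::size_t... Is>
std::tuple<gemmlowp::TransposeType<Ts>...> TransposeTupleImpl(
    const std::tuple<Ts...>& t, IndexSeq<Is...>) {
  // Expand Transpose over every tuple element in one pack expansion.
  return std::make_tuple(gemmlowp::Transpose(std::get<Is>(t))...);
}

template <typename... Ts>
std::tuple<gemmlowp::TransposeType<Ts>...> TransposeTuple(
    const std::tuple<Ts...>& t) {
  return TransposeTupleImpl(t, typename MakeIndexSeq<sizeof...(Ts)>::type());
}

}  // namespace gemmlowp_sketch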
diff --git a/internal/iterator.h b/internal/iterator.h
index 917694d..524cb80 100644
--- a/internal/iterator.h
+++ b/internal/iterator.h
@@ -34,8 +34,9 @@
 class ConstIterator<VectorMap<tScalar, tShape>> {
  public:
   typedef tScalar Scalar;
-  ConstIterator(const VectorMap<tScalar, tShape>& vector_map)
-      : pointer_(vector_map.data()) {}
+  ConstIterator(const VectorMap<tScalar, tShape>& vector_map,
+                const int start_offset)
+      : pointer_(vector_map.data() + start_offset) {}
   const Scalar operator*() const { return *pointer_; }
   const Scalar* get() const { return pointer_; }
   ConstIterator& operator+=(int inc) { pointer_ += inc; return *this; }
@@ -45,8 +46,9 @@
 
 template <typename tScalar, VectorShape tShape>
 ConstIterator<VectorMap<tScalar, tShape>> const_iterator(
-    const VectorMap<tScalar, tShape>& vector_map) {
-  return ConstIterator<VectorMap<tScalar, tShape>>(vector_map);
+    const VectorMap<tScalar, tShape>& vector_map,
+    const int start_offset) {
+  return ConstIterator<VectorMap<tScalar, tShape>>(vector_map, start_offset);
 }
 
 template <typename tScalar, VectorShape tShape> class VectorDup;
@@ -66,7 +68,8 @@
 
 template <typename tScalar, VectorShape tShape>
 ConstIterator<VectorDup<tScalar, tShape>> const_iterator(
-    const VectorDup<tScalar, tShape>& vector_map) {
+    const VectorDup<tScalar, tShape>& vector_map,
+    const int start_offset) {
   return ConstIterator<VectorDup<tScalar, tShape>>(vector_map);
 }
 
diff --git a/internal/kernel.h b/internal/kernel.h
index 1aceec7..4d006af 100644
--- a/internal/kernel.h
+++ b/internal/kernel.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -145,6 +145,12 @@
   static const int kCells = tCells;
   static const int kWidth = kCells * Cell::kWidth;
   static const int kDepth = Cell::kDepth;
+  typedef std::uint8_t Scalar;
+};
+
+template <typename tCellFormat, int tCells>
+struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
+  typedef std::int8_t Scalar;
 };
 
 // KernelFormat describes fully the input data layout that a kernel expects.
@@ -210,6 +216,19 @@
   virtual ~KernelBase() {}
 };
 
+template <typename KernelScalarType>
+struct ZeroPointInputValue {};
+
+template <>
+struct ZeroPointInputValue<std::uint8_t> {
+  static constexpr std::uint8_t kValue = 0;
+};
+
+template <>
+struct ZeroPointInputValue<std::int8_t> {
+  static constexpr std::uint8_t kValue = 128;
+};
+
 }  // namespace gemmlowp
 
 #endif  // GEMMLOWP_INTERNAL_KERNEL_H_
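One reading of the new ZeroPointInputValue constants: they record, per kernel scalar type, the uint8 source value that packs to a zero kernel operand. The int8 kernels added in this patch consume uint8 source bytes re-centered by 128 into [-128, 127], so the source byte that becomes a zero operand is 128, while plain uint8 kernels use the byte unchanged and the value stays 0. The ToKernelScalar helper below is hypothetical, written only to show that arithmetic:

#include <cstdint>

template <typename KernelScalarType>
KernelScalarType ToKernelScalar(std::uint8_t input);

template <>
inline std::uint8_t ToKernelScalar<std::uint8_t>(std::uint8_t input) {
  return input;  // uint8 kernels use the byte as-is: input 0 is operand 0.
}

template <>
inline std::int8_t ToKernelScalar<std::int8_t>(std::uint8_t input) {
  // Re-center [0, 255] onto [-128, 127]; input 128 becomes operand 0,
  // matching ZeroPointInputValue<std::int8_t>::kValue == 128.
  return static_cast<std::int8_t>(input - 128);
}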
diff --git a/internal/kernel_default.h b/internal/kernel_default.h
index 22bf4d0..bba0093 100644
--- a/internal/kernel_default.h
+++ b/internal/kernel_default.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -20,56 +20,86 @@
 
 #include "../public/bit_depth.h"
 #include "common.h"
+#include "kernel_reference.h"
 
 namespace gemmlowp {
 
-enum class KernelFamily { Gemm, Gemv };
+template <bool MaxProductIsLessThan4096,
+          bool LhsAlwaysNonzero>
+struct DefaultKernelImpl {};
 
-template <KernelFamily Family, int ProductBits>
-struct DefaultKernelImpl : DefaultKernelImpl<Family, ProductBits + 1> {
-  static_assert(ProductBits <= 16, "Bit depth too large");
-};
+// Partial specialization implementing the logic that if we want to use
+// a kernel for LhsAlwaysNonzero but do not have such a kernel, then we fall
+// back to a generic kernel not taking advantage of LhsAlwaysNonzero.
+template <bool LhsAlwaysNonzero>
+struct DefaultKernelImpl<true, LhsAlwaysNonzero>
+    : DefaultKernelImpl<false, LhsAlwaysNonzero> {};
 
-template <KernelFamily Family, typename BitDepthParams>
+// Partial specialization implementing the logic that if we want to use
+// a kernel for MaxProductIsLessThan4096 but do not have such a kernel, then we
+// fall back to a generic kernel not taking advantage of
+// MaxProductIsLessThan4096.
+template <bool MaxProductIsLessThan4096>
+struct DefaultKernelImpl<MaxProductIsLessThan4096, true>
+    : DefaultKernelImpl<MaxProductIsLessThan4096, false> {};
+
+template <typename BitDepthParams>
 struct DefaultKernel
-    : DefaultKernelImpl<Family, BitDepthParams::LhsBitDepth::kBits +
-                                    BitDepthParams::RhsBitDepth::kBits> {};
+    : DefaultKernelImpl<(BitDepthParams::LhsRange::kMaxValue *
+                             BitDepthParams::RhsRange::kMaxValue <
+                         4096),
+                        (BitDepthParams::LhsRange::kMinValue > 0)> {};
 
 }  // end namespace gemmlowp
 
-#define GEMMLOWP_SET_DEFAULT_KERNEL(op, max_product_bits, kernel)           \
-  namespace gemmlowp {                                                      \
-  template <>                                                               \
-  struct DefaultKernelImpl<KernelFamily::op, max_product_bits> : kernel {}; \
+#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \
+                                    LhsAlwaysNonzero, Kernel) \
+  namespace gemmlowp {                                        \
+  template <>                                                 \
+  struct DefaultKernelImpl<MaxProductIsLessThan4096,          \
+                           LhsAlwaysNonzero> : Kernel {};     \
   }
 
 #if defined GEMMLOWP_NEON_32
 #include "kernel_neon.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemm, 16, NEON_32_Kernel12x4Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemm, 12,
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_32_Kernel12x4Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(true, false,
                             NEON_32_Kernel12x4Depth2Assuming12BitProducts)
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemv, 16, NEONKernel4Nx1Depth2<3>)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
+                            NEON_32bit_GEMM_Int8Operands_LhsNonzero)
 #elif defined GEMMLOWP_NEON_64
 #include "kernel_neon.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemm, 16, NEON_64_Kernel12x8Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemv, 16, NEONKernel4Nx1Depth2<3>)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
+                            NEON_64bit_GEMM_Int8Operands_LhsNonzero)
 #elif defined GEMMLOWP_SSE4_32
-#include "kernel_SSE.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemm, 16, SSE4_32_Kernel4x4Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemv, 16, SSE4_32_Kernel4x4Depth2)
+#include "kernel_sse.h"
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2)
 #elif defined GEMMLOWP_SSE4_64
-#include "kernel_SSE.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemm, 16, SSE4_64_Kernel12x4Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemv, 16, SSE4_64_Kernel12x4Depth2)
+#include "kernel_sse.h"
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2)
 #else
+#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
+#if defined __ARM_ARCH_5TE__
+// SIMD is not available on this platform. The slow fallback will be used.
+// Don't require GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK because there's nothing
+// the user can do about it.
+#else
+#error \
+    "SIMD not enabled, you'd be getting a slow software fallback. Consider \
+enabling SIMD extensions (for example using -msse4 if you're on modern x86). \
+If that's not an option, and you would like to continue with the \
+slow fallback, define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK."
+#endif
+#endif
 #include "kernel_reference.h"
 namespace gemmlowp {
-typedef ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 4>, 2>,
-                                     KernelSideFormat<CellFormat<4, 4>, 2> > >
+typedef ReferenceKernel<KernelFormat<
+    KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+    KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > >
     DefaultReferenceKernel;
 }
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemm, 16, DefaultReferenceKernel)
-GEMMLOWP_SET_DEFAULT_KERNEL(Gemv, 16, DefaultReferenceKernel)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, DefaultReferenceKernel)
 #endif
 
 #endif  // GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_
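The kernel choice is now keyed on two booleans computed from the BitDepthParams ranges, with the two partial specializations providing a fallback when no specialized kernel is registered for the target. A hedged sketch of how the booleans evaluate, using made-up Example* types whose kMinValue/kMaxValue members mirror what DefaultKernel reads:

struct ExampleRange0To255 {
  static const int kMinValue = 0;
  static const int kMaxValue = 255;
};
struct ExampleRange1To255 {
  static const int kMinValue = 1;
  static const int kMaxValue = 255;
};
struct ExampleLhsNonzeroBitDepthParams {
  typedef ExampleRange1To255 LhsRange;
  typedef ExampleRange0To255 RhsRange;
};

// For these params:
//   MaxProductIsLessThan4096 = (255 * 255 < 4096) = false
//   LhsAlwaysNonzero         = (1 > 0)            = true
// so DefaultKernel resolves to DefaultKernelImpl<false, true>: on NEON that
// is one of the *_Int8Operands_LhsNonzero kernels registered above, while on
// targets that only register <false, false> the partial specialization
// <MaxProductIsLessThan4096, true> makes <false, true> inherit from
// <false, false>.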
diff --git a/internal/kernel_neon.h b/internal/kernel_neon.h
index 74b5fec..5c253ba 100644
--- a/internal/kernel_neon.h
+++ b/internal/kernel_neon.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -49,30 +49,13 @@
 //  so use numerical ones instead. See
 // http://stackoverflow.com/questions/3898435/labels-in-gcc-inline-assembly
 // If you add any labels, remember to undef them at the end.
-#define GEMMLOWP_LOOP_NEON_KERNEL_12X4_DEPTH2 "1"
-#define GEMMLOWP_STORE_RESULT_NEON_KERNEL_12X4_DEPTH2 "2"
+#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
+#define GEMMLOWP_LABEL_BEFORE_LOOP "2"
+#define GEMMLOWP_LABEL_LOOP "3"
+#define GEMMLOWP_LABEL_AFTER_LOOP "4"
 
     assert(dst_row_stride == 1);
     asm volatile(
-        // Clear accumulator registers (see layout below)
-        "vmov.s32 q4, #0\n"
-        "vmov.s32 q8, q4\n"
-        "vmov.s32 q12, q4\n"
-        "vmov.s32 q5, q4\n"
-        "vmov.s32 q9, q4\n"
-        "vmov.s32 q13, q4\n"
-        "vmov.s32 q6, q4\n"
-        "vmov.s32 q10, q4\n"
-        "vmov.s32 q14, q4\n"
-        "vmov.s32 q7, q4\n"
-        "vmov.s32 q11, q4\n"
-        "vmov.s32 q15, q4\n"
-
-        /* Main loop */
-
-        GEMMLOWP_LOOP_NEON_KERNEL_12X4_DEPTH2
-        ":\n"
-
         // Overview of register layout:
         //
         // A 2x4 cell of Rhs is stored in 16bit in d0--d1 (q0).
@@ -110,12 +93,125 @@
         //                            Accumulator
 
         // Load 1 Rhs cell of size 2x4
-        "vld1.8 {d0}, [%[rhs_ptr]:64]!\n"
-
+        "vld1.8 {d0}, [%[rhs_ptr]]!\n"
         // Load 3 Lhs cells of size 4x2 each
-        "vld1.8 {d2}, [%[lhs_ptr]:64]!\n"
-        "vld1.8 {d4}, [%[lhs_ptr]:64]!\n"
-        "vld1.8 {d6}, [%[lhs_ptr]:64]!\n"
+        "vld1.8 {d2}, [%[lhs_ptr]]!\n"
+        "vld1.8 {d4}, [%[lhs_ptr]]!\n"
+        "vld1.8 {d6}, [%[lhs_ptr]]!\n"
+
+        // Check if start_depth==0 to decide whether we will clear
+        // accumulators or load existing accumulators.
+        "cmp %[start_depth], #0\n"
+
+        // Multiply dst_col_stride by 4 == sizeof(int32) to use
+        // it as a byte offset below.
+        "lsl %[dst_col_stride], #2\n"
+
+        "beq " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+        "f\n"
+
+        // Load accumulators (start_depth != 0)
+        "mov r1, %[dst_ptr]\n"
+        "subs %[run_depth], #2\n"
+        "mov r0, r1\n"
+        "vld1.32 {d8, d9},   [r0]!\n"
+        "add r1, %[dst_col_stride]\n"
+        "vld1.32 {d16, d17}, [r0]!\n"
+        "vld1.32 {d24, d25}, [r0]\n"
+        "mov r0, r1\n"
+        "vld1.32 {d10, d11}, [r0]!\n"
+        "add r1, %[dst_col_stride]\n"
+        "vld1.32 {d18, d19}, [r0]!\n"
+        "vld1.32 {d26, d27}, [r0]\n"
+        "mov r0, r1\n"
+        "vld1.32 {d12, d13}, [r0]!\n"
+        "add r1, %[dst_col_stride]\n"
+        "vld1.32 {d20, d21}, [r0]!\n"
+        "vld1.32 {d28, d29}, [r0]\n"
+        "mov r0, r1\n"
+        "vld1.32 {d14, d15}, [r0]!\n"
+        "vld1.32 {d22, d23}, [r0]!\n"
+        "vld1.32 {d30, d31}, [r0]\n"
+
+        "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
+
+        GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+        ":\n"
+
+        // Clear accumulators (start_depth == 0)
+        "vmov.s32 q4, #0\n"
+        "subs %[run_depth], #2\n"
+        "vmov.s32 q8, q4\n"
+        "vmov.s32 q12, q4\n"
+        "vmov.s32 q5, q4\n"
+        "vmov.s32 q9, q4\n"
+        "vmov.s32 q13, q4\n"
+        "vmov.s32 q6, q4\n"
+        "vmov.s32 q10, q4\n"
+        "vmov.s32 q14, q4\n"
+        "vmov.s32 q7, q4\n"
+        "vmov.s32 q11, q4\n"
+        "vmov.s32 q15, q4\n"
+
+        GEMMLOWP_LABEL_BEFORE_LOOP
+        ":\n"
+
+        // If there are only two levels of depth, skip the loop.
+        "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"
+
+        GEMMLOWP_LABEL_LOOP
+        ":\n"
+        // Expand Lhs/Rhs cells to 16 bit.
+        // Note: moving these vmovls further down to allow for
+        // longer data pipelining helps a little on A57 but is
+        // harmful on A53 --- It looks as if A53 doesn't like
+        // interleaving vmovl's into the vmlal's.
+        "vmovl.u8 q0, d0\n"
+        "vmovl.u8 q1, d2\n"
+        "vmovl.u8 q2, d4\n"
+        "vmovl.u8 q3, d6\n"
+
+        // Multiply-accumulate, level of depth 0
+        "vmlal.u16 q4, d2, d0[0]\n"
+        "vmlal.u16 q5, d2, d0[1]\n"
+        "vmlal.u16 q6, d2, d0[2]\n"
+        "vmlal.u16 q7, d2, d0[3]\n"
+        "vldr d2, [%[lhs_ptr]]\n"
+        "vmlal.u16 q8, d4, d0[0]\n"
+        "vmlal.u16 q9, d4, d0[1]\n"
+        "vmlal.u16 q10, d4, d0[2]\n"
+        "vmlal.u16 q11, d4, d0[3]\n"
+        "vldr d4, [%[lhs_ptr], #8]\n"
+        "vmlal.u16 q12, d6, d0[0]\n"
+        "vmlal.u16 q13, d6, d0[1]\n"
+        "vmlal.u16 q14, d6, d0[2]\n"
+        "vmlal.u16 q15, d6, d0[3]\n"
+        "vldr d6, [%[lhs_ptr], #16]\n"
+        "vldr d0, [%[rhs_ptr]]\n"
+
+        // Multiply-accumulate, level of depth 1
+        "vmlal.u16 q4, d3, d1[0]\n"
+        "vmlal.u16 q5, d3, d1[1]\n"
+        "add %[lhs_ptr], #24\n"
+        "vmlal.u16 q6, d3, d1[2]\n"
+        "vmlal.u16 q7, d3, d1[3]\n"
+        "add %[rhs_ptr], #8\n"
+        "vmlal.u16 q8, d5, d1[0]\n"
+        "vmlal.u16 q9, d5, d1[1]\n"
+        "subs %[run_depth], #2\n"
+        "vmlal.u16 q10, d5, d1[2]\n"
+        "vmlal.u16 q11, d5, d1[3]\n"
+        "vmlal.u16 q12, d7, d1[0]\n"
+        "vmlal.u16 q13, d7, d1[1]\n"
+        "vmlal.u16 q14, d7, d1[2]\n"
+        "vmlal.u16 q15, d7, d1[3]\n"
+
+        "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+        GEMMLOWP_LABEL_AFTER_LOOP
+        ":\n"
+
+        // Do remaining arithmetic for the last 2 levels of depth.
 
         // Expand Lhs/Rhs cells to 16 bit.
         "vmovl.u8 q0, d0\n"
@@ -151,99 +247,27 @@
         "vmlal.u16 q14, d7, d1[2]\n"
         "vmlal.u16 q15, d7, d1[3]\n"
 
-        // Loop. Decrement loop index (depth) by 2, since we just handled 2
-        // levels of depth (Kernel::kDepth=2).
-        "subs %[run_depth], #2\n"
-        "bne " GEMMLOWP_LOOP_NEON_KERNEL_12X4_DEPTH2
-        "b\n"
-
-        /* end of main loop */
-
-        /* Accumulate our local accumulator registers into the destination block
-           */
-
-        // Compute stride between consecutive columns, in bytes
-        "mov r0, #4\n"  // multiply by 4 = sizeof(int32)
-        "mul %[dst_col_stride], r0\n"
-
-        // If start_depth == 0, then there is no preexisting accumulator
-        // to accumulate, so we can simply store our result.
-        "cmp %[start_depth], #0\n"
-        "beq " GEMMLOWP_STORE_RESULT_NEON_KERNEL_12X4_DEPTH2
-        "f\n"
-
-        "mov r0, %[dst_ptr]\n"
-
-        // Load a column
-        "mov r1, r0\n"
-        "vld1.32 {d0, d1}, [r1]!\n"
-        "vld1.32 {d2, d3}, [r1]!\n"
-        "vld1.32 {d4, d5}, [r1]!\n"
-        // Accumulate a column
-        "vadd.s32 q4, q4, q0\n"
-        "vadd.s32 q8, q8, q1\n"
-        "vadd.s32 q12, q12, q2\n"
-
-        "add r0, %[dst_col_stride]\n"
-        // Load a column
-        "mov r1, r0\n"
-        "vld1.32 {d0, d1}, [r1]!\n"
-        "vld1.32 {d2, d3}, [r1]!\n"
-        "vld1.32 {d4, d5}, [r1]!\n"
-        // Accumulate a column
-        "vadd.s32 q5, q5, q0\n"
-        "vadd.s32 q9, q9, q1\n"
-        "vadd.s32 q13, q13, q2\n"
-
-        "add r0, %[dst_col_stride]\n"
-        // Load a column
-        "mov r1, r0\n"
-        "vld1.32 {d0, d1}, [r1]!\n"
-        "vld1.32 {d2, d3}, [r1]!\n"
-        "vld1.32 {d4, d5}, [r1]!\n"
-        // Accumulate a column
-        "vadd.s32 q6, q6, q0\n"
-        "vadd.s32 q10, q10, q1\n"
-        "vadd.s32 q14, q14, q2\n"
-
-        "add r0, %[dst_col_stride]\n"
-        // Load a column
-        "mov r1, r0\n"
-        "vld1.32 {d0, d1}, [r1]!\n"
-        "vld1.32 {d2, d3}, [r1]!\n"
-        "vld1.32 {d4, d5}, [r1]!\n"
-        // Accumulate a column
-        "vadd.s32 q7, q7, q0\n"
-        "vadd.s32 q11, q11, q1\n"
-        "vadd.s32 q15, q15, q2\n"
-
-        GEMMLOWP_STORE_RESULT_NEON_KERNEL_12X4_DEPTH2
-        ":\n"
-
-        "mov r0, %[dst_ptr]\n"
-        // Store a column
-        "mov r1, r0\n"
-        "vst1.32 {d8, d9}, [r1]!\n"
-        "vst1.32 {d16, d17}, [r1]!\n"
-        "vst1.32 {d24, d25}, [r1]!\n"
-        // Store a column
-        "add r0, %[dst_col_stride]\n"
-        "mov r1, r0\n"
-        "vst1.32 {d10, d11}, [r1]!\n"
-        "vst1.32 {d18, d19}, [r1]!\n"
-        "vst1.32 {d26, d27}, [r1]!\n"
-        // Store a column
-        "add r0, %[dst_col_stride]\n"
-        "mov r1, r0\n"
-        "vst1.32 {d12, d13}, [r1]!\n"
-        "vst1.32 {d20, d21}, [r1]!\n"
-        "vst1.32 {d28, d29}, [r1]!\n"
-        // Store a column
-        "add r0, %[dst_col_stride]\n"
-        "mov r1, r0\n"
-        "vst1.32 {d14, d15}, [r1]!\n"
-        "vst1.32 {d22, d23}, [r1]!\n"
-        "vst1.32 {d30, d31}, [r1]!\n"
+        // Store accumulators
+        "mov r1, %[dst_ptr]\n"
+        "mov r0, r1\n"
+        "vst1.32 {d8, d9},   [r0]!\n"
+        "add r1, %[dst_col_stride]\n"
+        "vst1.32 {d16, d17}, [r0]!\n"
+        "vst1.32 {d24, d25}, [r0]\n"
+        "mov r0, r1\n"
+        "vst1.32 {d10, d11}, [r0]!\n"
+        "add r1, %[dst_col_stride]\n"
+        "vst1.32 {d18, d19}, [r0]!\n"
+        "vst1.32 {d26, d27}, [r0]\n"
+        "mov r0, r1\n"
+        "vst1.32 {d12, d13}, [r0]!\n"
+        "add r1, %[dst_col_stride]\n"
+        "vst1.32 {d20, d21}, [r0]!\n"
+        "vst1.32 {d28, d29}, [r0]\n"
+        "mov r0, r1\n"
+        "vst1.32 {d14, d15}, [r0]!\n"
+        "vst1.32 {d22, d23}, [r0]!\n"
+        "vst1.32 {d30, d31}, [r0]\n"
         :  // outputs
         [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
         [dst_ptr] "+r"(dst_ptr),
@@ -259,8 +283,10 @@
         "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
         "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30",
         "d31");
-#undef GEMMLOWP_LOOP_NEON_KERNEL_12X4_DEPTH2
-#undef GEMMLOWP_STORE_RESULT_NEON_KERNEL_12X4_DEPTH2
+#undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+#undef GEMMLOWP_LABEL_BEFORE_LOOP
+#undef GEMMLOWP_LABEL_LOOP
+#undef GEMMLOWP_LABEL_AFTER_LOOP
   }
 };
 
@@ -638,11 +664,604 @@
   }
 };
 
+struct NEON_32bit_GEMM_Int8Operands_LhsNonzero : KernelBase {
+  typedef KernelFormat<
+      KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+      KernelSideFormatInt8<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
+      Format;
+  const char* Name() const override {
+    return "NEON, 4x2, depth 16, accumulating two within signed int16";
+  }
+
+  // TODO(benoitjacob): reorder function arguments so dst comes last
+  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
+           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
+           const std::uint8_t* rhs_ptr, std::size_t start_depth,
+           std::size_t run_depth) const override {
+#define GEMMLOWP_LABEL_AFTER_LOOP "1"
+#define GEMMLOWP_LABEL_LOOP "2"
+#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
+#define GEMMLOWP_LABEL_STORE "4"
+    asm volatile(
+        // Multiply dst_col_stride by 4 == sizeof(int32) to use
+        // it as a byte offset below.
+        "lsl %[dst_col_stride], %[dst_col_stride], #2\n"
+
+        // Overview of register layout:
+        //
+        // A 2x16 block of Rhs is stored in 8 bit in d0--d3.
+        // A 4x16 block of Lhs is stored in 8 bit in d4--d7. That is only
+        // half of the register space required, so we loop over these registers
+        // twice. Only half of it, a 2x16 block, is stored in d4--d7 at
+        // any given time.
+        //
+        // A 4x2 block of accumulators is stored in q8--q15 (as 4x32 bit
+        // components which need to be horizontally-added at the end)
+        //
+        // The Lhs vectors are multiplied by the Rhs vectors with a widening
+        // multiply over the 8 first levels of depth, producing int16x8
+        // vectors of products for each position in the accumulator matrix.
+        // Here comes the special trick: since the operands are signed int8,
+        // their range being [ -2^7 , 2^7 ), their products are in range
+        // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values
+        // without any risk of overflowing int16.
+        // We thus proceed with the 8 next levels of depth, multiplying
+        // again Lhs by Rhs, accumulating into this existing int16x8 vector.
+        //
+        // Only then, having processed 16 levels of depth, do we need to
+        // horizontally add these int16x8 accumulators into the final
+        // int32x4 accumulators.
+        //
+        // As we do not have enough registers to store all 16 int16x8
+        // temporary-16bit-accumulators, we have them cycle through q4--q7.
+        //
+        //
+        // Register layout (ignoring the q4--q7 temporary 16bit accumulators):
+        //
+        //                               +----+----+
+        //                               | d0 | d2 |
+        //                               | .  | .  |
+        //                               | .  | .  |
+        //                               | .  | .  |
+        //                       Rhs     +----+----+
+        //                               | d1 | d3 |
+        //                               | .  | .  |
+        //                               | .  | .  |
+        //                               | .  | .  |
+        //                               +----+----+
+        //
+        //                               |    |    |
+        //
+        //    Lhs                        |    |    |
+        //
+        //  +--------+--------+ - - - -  +----+----+
+        //  | d4 ... | d5 ... |          | q8 | q9 |
+        //  | d6 ... | d7 ... |          | q10| q11|
+        //  | d4 ... | d5 ... |          | q12| q13|
+        //  | d6 ... | d7 ... |          | q14| q15|
+        //  +--------+--------+ - - - -  +----+----+
+        //
+        //                               Accumulator
+        //
+
+        // Clear accumulators, and, interleaved with it,
+        // initial loads of the first loop iteration,
+        // taken out of the loop so that in the loop itself we have
+        // optimal streaming of data from memory.
+        "vldr d0, [%[rhs_ptr], #0]\n"
+        "vmov.i32 q8, #0\n"
+        "vldr d4, [%[lhs_ptr], #0]\n"
+        "vmov.i32 q9, #0\n"
+        "vldr d2, [%[rhs_ptr], #16]\n"
+        "vmov.i32 q10, q8\n"
+        "vldr d6, [%[lhs_ptr], #16]\n"
+        "vmov.i32 q11, q8\n"
+        "vldr d1, [%[rhs_ptr], #8]\n"
+        "vmov.i32 q12, q8\n"
+        "vldr d5, [%[lhs_ptr], #8]\n"
+        "vmov.i32 q13, q8\n"
+        "vldr d3, [%[rhs_ptr], #24]\n"
+        "vmov.i32 q14, q8\n"
+        "vldr d7, [%[lhs_ptr], #24]\n"
+        "vmov.i32 q15, q8\n"
+
+        // General loop.
+        GEMMLOWP_LABEL_LOOP
+        ":\n"
+
+        // Multiply 8 first levels of depth.
+        "vmull.s8    q4,  d0,  d4\n"
+        "add %[rhs_ptr], %[rhs_ptr], #32\n"
+        "vmull.s8    q5,  d2,  d4\n"
+        "vldr d4, [%[lhs_ptr], #32]\n"
+        "vmull.s8    q6,  d0,  d6\n"
+        "vmull.s8    q7,  d2,  d6\n"
+        "vldr d6, [%[lhs_ptr], #48]\n"
+
+        // Multiply-accumulate second-half, again into the same
+        // 16bit local accumulator registers. This is where we
+        // take advantage of having int8 instead of uint8 and therefore
+        // being able to accumulate two products into int16.
+        "vmlal.s8    q4,  d1,  d5\n"
+        "vmlal.s8    q5,  d3,  d5\n"
+        "vldr d5, [%[lhs_ptr], #40]\n"
+        "vmlal.s8    q6,  d1,  d7\n"
+        "vmlal.s8    q7,  d3,  d7\n"
+        "vldr d7, [%[lhs_ptr], #56]\n"
+
+        // Add pairwise, accumulate into 32-bit accumulators.
+        "vpadal.s16   q8,  q4\n"
+        "add %[lhs_ptr], %[lhs_ptr], #64\n"
+        "vpadal.s16   q9,  q5\n"
+        "subs %[run_depth], %[run_depth], #16\n"
+        "vpadal.s16   q10, q6\n"
+        "vpadal.s16   q11, q7\n"
+
+        "beq " GEMMLOWP_LABEL_AFTER_LOOP
+        "f\n"
+
+        // Multiply first half.
+        "vmull.s8    q4,  d0,  d4\n"
+        "vmull.s8    q5,  d2,  d4\n"
+        "vldr d4, [%[lhs_ptr], #0]\n"
+        "vmull.s8    q6,  d0,  d6\n"
+        "vldr d0, [%[rhs_ptr], #0]\n"
+        "vmull.s8    q7,  d2,  d6\n"
+        "vldr d2, [%[rhs_ptr], #16]\n"
+
+        // Multiply-accumulate second-half, again into the same
+        // 16bit local accumulator registers. This is where we
+        // take advantage of having int8 instead of uint8 and therefore
+        // being able to accumulate two products into int16.
+        "vmlal.s8    q4,  d1,  d5\n"
+        "vldr d6, [%[lhs_ptr], #16]\n"
+        "vmlal.s8    q5,  d3,  d5\n"
+        "vldr d5, [%[lhs_ptr], #8]\n"
+        "vmlal.s8    q6,  d1,  d7\n"
+        "vldr d1, [%[rhs_ptr], #8]\n"
+        "vmlal.s8    q7,  d3,  d7\n"
+        "vldr d3, [%[rhs_ptr], #24]\n"
+
+        // Add pairwise, accumulate into 32-bit accumulators.
+        "vpadal.s16   q12, q4\n"
+        "vldr d7, [%[lhs_ptr], #24]\n"
+        "vpadal.s16   q13, q5\n"
+        "vpadal.s16   q14, q6\n"
+        "vpadal.s16   q15, q7\n"
+
+        "b " GEMMLOWP_LABEL_LOOP "b\n"
+
+        GEMMLOWP_LABEL_AFTER_LOOP
+        ":\n"
+
+        // Multiply first half.
+        "vmull.s8    q4,  d0,  d4\n"
+        "vmull.s8    q5,  d2,  d4\n"
+        "vmull.s8    q6,  d0,  d6\n"
+        "vmull.s8    q7,  d2,  d6\n"
+
+        // Multiply-accumulate second-half, again into the same
+        // 16bit local accumulator registers. This is where we
+        // take advantage of having int8 instead of uint8 and therefore
+        // being able to accumulate two products into int16.
+        "vmlal.s8    q4,  d1,  d5\n"
+        "vmlal.s8    q5,  d3,  d5\n"
+        "vmlal.s8    q6,  d1,  d7\n"
+        "vmlal.s8    q7,  d3,  d7\n"
+
+        // Add pairwise, accumulate into 32-bit accumulators.
+        "vpadal.s16   q12, q4\n"
+        "vpadal.s16   q13, q5\n"
+        "vpadal.s16   q14, q6\n"
+        "vpadal.s16   q15, q7\n"
+        "cmp %[start_depth], #0\n"
+
+        // Reduce 32bit accumulators horizontally.
+        "vpadd.s32 d0, d16, d17\n"
+        "vpadd.s32 d1, d18, d19\n"
+        "vpadd.s32 d2, d20, d21\n"
+        "vpadd.s32 d3, d22, d23\n"
+        "vpadd.s32 d4, d24, d25\n"
+        "vpadd.s32 d5, d26, d27\n"
+        "vpadd.s32 d6, d28, d29\n"
+        "vpadd.s32 d7, d30, d31\n"
+
+        "bne " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
+        "f\n"
+
+        // Reduce 32bit accumulators horizontally, second pass
+        // (each pass adds pairwise. we need to add 4-wise).
+        "vpadd.s32 d8, d0, d2\n"
+        "vpadd.s32 d9, d4, d6\n"
+        "vpadd.s32 d10, d1, d3\n"
+        "vpadd.s32 d11, d5, d7\n"
+
+        "b " GEMMLOWP_LABEL_STORE "f\n"
+
+        GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
+        ":\n"
+
+        // Reduce 32bit accumulators horizontally, second pass
+        // (each pass adds pairwise. we need to add 4-wise),
+        // and load destination values from memory.
+        "mov r0, %[dst_ptr]\n"
+        "vld1.32 {d16, d17}, [r0], %[dst_col_stride]\n"
+        "vpadd.s32 d8, d0, d2\n"
+        "vpadd.s32 d9, d4, d6\n"
+        "vld1.32 {d18, d19}, [r0]\n"
+        "vpadd.s32 d10, d1, d3\n"
+        "vpadd.s32 d11, d5, d7\n"
+
+        // Add horizontally-reduced accumulators into
+        // the values loaded from memory
+        "vadd.s32 q4, q8, q4\n"
+        "vadd.s32 q5, q9, q5\n"
+
+        GEMMLOWP_LABEL_STORE
+        ":\n"
+        // Store back into memory
+        "mov r0, %[dst_ptr]\n"
+        "vst1.32 {d8, d9}, [r0], %[dst_col_stride]\n"
+        "vst1.32 {d10, d11}, [r0]\n"
+        :  // outputs
+        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+        [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth)
+        :  // inputs
+        [start_depth] "r"(start_depth),
+        [dst_col_stride] "r"(dst_col_stride)
+        :  // clobbers
+        "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
+        "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
+        "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
+        "d28", "d29", "d30", "d31");
+#undef GEMMLOWP_LABEL_LOOP
+#undef GEMMLOWP_LABEL_AFTER_LOOP
+#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
+#undef GEMMLOWP_LABEL_STORE
+  }
+};
+
 #endif  // GEMMLOWP_NEON_32
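The register-layout comment in the kernel above argues that two int8 products can be accumulated into an int16 lane before the widening vpadal step. A worked bound check of that argument, under the LhsAlwaysNonzero assumption (a never-zero uint8 lhs byte re-centered by 128 lands in [-127, 127], while rhs spans the full [-128, 127]); these asserts are a sketch, not library code:

static_assert(127 * 128 == 16256, "largest possible |lhs * rhs| product");
static_assert(2 * 16256 == 32512, "two products accumulated into one int16 lane");
static_assert(32512 <= 32767, "within int16 range, so no overflow before vpadal");

With a full-range lhs the bound would instead be 2 * 128 * 128 = 32768, one past INT16_MAX, which is presumably why these kernels are gated on LhsAlwaysNonzero.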
 
 // The kernels here are specifically arm 64bit assembly, not arm 32bit.
 #ifdef GEMMLOWP_NEON_64
 
+struct NEON_64bit_GEMM_Int8Operands_LhsNonzero : KernelBase {
+  typedef KernelFormat<
+      KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+      KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
+      Format;
+  const char* Name() const override {
+    return "NEON, 4x4, depth 16, accumulating two within signed int16";
+  }
+
+  // TODO(benoitjacob): reorder function arguments so dst comes last
+  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
+           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
+           const std::uint8_t* rhs_ptr, std::size_t start_depth,
+           std::size_t run_depth) const override {
+#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
+#define GEMMLOWP_LABEL_LOOP "2"
+#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
+#define GEMMLOWP_LABEL_STORE "4"
+    asm volatile(
+        // Clear accumulators, and, interleaved with it,
+        // initial loads of the first loop iteration,
+        // taken out of the loop so that in the loop itself we have
+        // optimal streaming of data from memory.
+        "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+        "dup v16.4s, wzr\n"
+        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+        "dup v17.4s, wzr\n"
+        "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
+        "dup v18.4s, wzr\n"
+        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+        "dup v19.4s, wzr\n"
+        "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+        "dup v20.4s, wzr\n"
+        "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+        "dup v21.4s, wzr\n"
+        "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+        "dup v22.4s, wzr\n"
+        "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+        "dup v23.4s, wzr\n"
+        "dup v24.4s, wzr\n"
+        "dup v25.4s, wzr\n"
+        "dup v26.4s, wzr\n"
+        "dup v27.4s, wzr\n"
+        "dup v28.4s, wzr\n"
+        "dup v29.4s, wzr\n"
+        "dup v30.4s, wzr\n"
+        "dup v31.4s, wzr\n"
+
+        // Multiply dst_col_stride by 4 == sizeof(int32) to use
+        // it as a byte offset below.
+        "lsl %[dst_col_stride], %[dst_col_stride], #2\n"
+
+        // Initial arithmetic of the first loop iteration,
+        // taken out of the loop so that in the loop itself we have
+        // optimal streaming of data from memory.
+        "smull    v8.8h,  v0.8b,  v4.8b\n"
+        "smull    v9.8h,  v1.8b,  v4.8b\n"
+        "smull    v10.8h,  v2.8b,  v4.8b\n"
+        "smull    v11.8h,  v3.8b,  v4.8b\n"
+        "smull    v12.8h,  v0.8b,  v5.8b\n"
+        "smull    v13.8h,  v1.8b,  v5.8b\n"
+        "smull    v14.8h,  v2.8b,  v5.8b\n"
+        "smull    v15.8h,  v3.8b,  v5.8b\n"
+
+        // Multiply-accumulate second-half, again into the same
+        // 16bit local accumulator registers. This is where we
+        // take advantage of having int8 instead of uint8 and therefore
+        // being able to accumulate two products into int16.
+        "smlal2   v8.8h,  v0.16b,  v4.16b\n"
+        "smlal2   v9.8h,  v1.16b,  v4.16b\n"
+        "smlal2   v10.8h,  v2.16b,  v4.16b\n"
+        "smlal2   v11.8h,  v3.16b,  v4.16b\n"
+        "smlal2   v12.8h,  v0.16b,  v5.16b\n"
+        "smlal2   v13.8h,  v1.16b,  v5.16b\n"
+        "smlal2   v14.8h,  v2.16b,  v5.16b\n"
+        "smlal2   v15.8h,  v3.16b,  v5.16b\n"
+
+        "subs %[run_depth], %[run_depth], #16\n"
+
+        // If the loop depth is only 16, then we can skip the general loop
+        // and go straight to the final part of the code.
+        "beq " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n"
+
+        // General loop.
+        GEMMLOWP_LABEL_LOOP
+        ":\n"
+
+        // Overview of register layout:
+        //
+        // A 4x16 block of Rhs is stored in 8 bit in v0--v3.
+        // A 4x16 block of Lhs is stored in 8 bit in v4--v7.
+        //
+        // A 4x4 block of accumulators is stored in v16-v31 (as 4x32 bit
+        // components which need to be horizontally-added at the end)
+        //
+        // The Lhs vectors are multiplied by the Rhs vectors with a widening
+        // multiply over the 8 first levels of depth, producing int16x8
+        // vectors of products for each position in the accumulator matrix.
+        // Here comes the special trick: since the operands are signed int8,
+        // their range being [ -2^7 , 2^7 ), their products are in range
+        // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values
+        // without any risk of overflowing int16.
+        // We thus proceed with the 8 next levels of depth, multiplying
+        // again Lhs by Rhs, accumulating into this existing int16x8 vector.
+        //
+        // Only then, having processed 16 levels of depth, do we need to
+        // horizontally add these int16x8 accumulators into the final
+        // int32x4 accumulators.
+        //
+        // As we do not have enough registers to store all 16 int16x8
+        // temporary-16bit-accumulators, we have them cycle through v8--v15.
+        //
+        //
+        // Register layout (ignoring the v8--v15 temporary 16bit accumulators):
+        //
+        //                               +--------+--------+--------+--------+
+        //                               |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] |
+        //                          Rhs  +--------+--------+--------+--------+
+        //                               |  ...   |  ...   |  ...   |  ...   |
+        //                               +--------+--------+--------+--------|
+        //                               |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]|
+        //                               +--------+--------+--------+--------+
+        //
+        //                               |        |        |        |        |
+        //
+        //    Lhs                        |        |        |        |        |
+        //
+        //  +-------+-----+--------+ - - +--------+--------+--------+--------+
+        //  |v4.b[0]| ... |v4.b[15]|     | v16.4s | v17.4s | v18.4s | v19.4s |
+        //  |v5.b[0]| ... |v5.b[15]|     | v20.4s | v21.4s | v22.4s | v23.4s |
+        //  |v6.b[0]| ... |v6.b[15]|     | v24.4s | v25.4s | v26.4s | v27.4s |
+        //  |v7.b[0]| ... |v7.b[15]|     | v28.4s | v29.4s | v30.4s | v31.4s |
+        //  +-------+--------------+ - - +--------+--------+--------+--------+
+        //
+        //                                                Accumulator
+        //
+
+        // Some multiplications and 16-bit accumulation were already done above,
+        // so we start right away in the middle.
+        "sadalp  v16.4s, v8.8h\n"
+        "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+        "smull    v8.8h,  v0.8b,  v6.8b\n"
+        "sadalp  v17.4s, v9.8h\n"
+        "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+        "smull    v9.8h,  v1.8b,  v6.8b\n"
+        "sadalp  v18.4s, v10.8h\n"
+        "smull    v10.8h,  v2.8b,  v6.8b\n"
+        "sadalp  v19.4s, v11.8h\n"
+        "smull    v11.8h,  v3.8b,  v6.8b\n"
+        "sadalp  v20.4s, v12.8h\n"
+        "smull    v12.8h,  v0.8b,  v7.8b\n"
+        "sadalp  v21.4s, v13.8h\n"
+        "smull    v13.8h,  v1.8b,  v7.8b\n"
+        "sadalp  v22.4s, v14.8h\n"
+        "smull    v14.8h,  v2.8b,  v7.8b\n"
+        "sadalp  v23.4s, v15.8h\n"
+        "smull    v15.8h,  v3.8b,  v7.8b\n"
+
+        // Multiply-accumulate second-half, again into the same
+        // 16bit local accumulator registers. This is where we
+        // take advantage of having int8 instead of uint8 and therefore
+        // being able to accumulate two products into int16.
+        "smlal2   v8.8h,  v0.16b,  v6.16b\n"
+        "smlal2   v9.8h,  v1.16b,  v6.16b\n"
+        "smlal2   v10.8h,  v2.16b,  v6.16b\n"
+        "smlal2   v11.8h,  v3.16b,  v6.16b\n"
+
+        "ld1 {v6.16b}, [%[lhs_ptr]], #16\n"
+
+        "smlal2   v12.8h,  v0.16b,  v7.16b\n"
+        "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+        "smlal2   v13.8h,  v1.16b,  v7.16b\n"
+        "ld1 {v1.16b}, [%[rhs_ptr]], #16\n"
+        "smlal2   v14.8h,  v2.16b,  v7.16b\n"
+        "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+        "smlal2   v15.8h,  v3.16b,  v7.16b\n"
+        "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+
+        "sadalp  v24.4s, v8.8h\n"
+        "smull    v8.8h,  v0.8b,  v4.8b\n"
+        "sadalp  v25.4s, v9.8h\n"
+        "ld1 {v7.16b}, [%[lhs_ptr]], #16\n"
+        "smull    v9.8h,  v1.8b,  v4.8b\n"
+        "sadalp  v26.4s, v10.8h\n"
+        "smull    v10.8h,  v2.8b,  v4.8b\n"
+        "sadalp  v27.4s, v11.8h\n"
+        "smull    v11.8h,  v3.8b,  v4.8b\n"
+        "sadalp  v28.4s, v12.8h\n"
+        "smull    v12.8h,  v0.8b,  v5.8b\n"
+        "sadalp  v29.4s, v13.8h\n"
+        "smull    v13.8h,  v1.8b,  v5.8b\n"
+        "sadalp  v30.4s, v14.8h\n"
+        "smull    v14.8h,  v2.8b,  v5.8b\n"
+        "sadalp  v31.4s, v15.8h\n"
+        "smull    v15.8h,  v3.8b,  v5.8b\n"
+
+        // Multiply-accumulate second-half, again into the same
+        // 16bit local accumulator registers. This is where we
+        // take advantage of having int8 instead of uint8 and therefore
+        // being able to accumulate two products into int16.
+        "smlal2   v8.8h,  v0.16b,  v4.16b\n"
+        "smlal2   v9.8h,  v1.16b,  v4.16b\n"
+        "smlal2   v10.8h,  v2.16b,  v4.16b\n"
+        "smlal2   v11.8h,  v3.16b,  v4.16b\n"
+
+        // Loop. Decrement loop index (depth) by 16, since we just handled
+        // 16 levels of depth.  Do this subs a bit before the end of the loop
+        // for better dispatch on A57.
+        "subs %[run_depth], %[run_depth], #16\n"
+
+        "smlal2   v12.8h,  v0.16b,  v5.16b\n"
+        "smlal2   v13.8h,  v1.16b,  v5.16b\n"
+        "smlal2   v14.8h,  v2.16b,  v5.16b\n"
+        "smlal2   v15.8h,  v3.16b,  v5.16b\n"
+
+        "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+        // Final code for the last 16 levels of depth.
+        // There is nothing to load anymore, only some arithmetic to finish.
+        GEMMLOWP_LABEL_AFTER_LOOP_LAST16
+        ":\n"
+
+        // Some multiplications and 16-bit accumulation were already done above,
+        // so we start right away in the middle.
+        "sadalp  v16.4s, v8.8h\n"
+        "smull    v8.8h,  v0.8b,  v6.8b\n"
+        "sadalp  v17.4s, v9.8h\n"
+        "smull    v9.8h,  v1.8b,  v6.8b\n"
+        "sadalp  v18.4s, v10.8h\n"
+        "smull    v10.8h,  v2.8b,  v6.8b\n"
+        "sadalp  v19.4s, v11.8h\n"
+        "smull    v11.8h,  v3.8b,  v6.8b\n"
+        "sadalp  v20.4s, v12.8h\n"
+        "smull    v12.8h,  v0.8b,  v7.8b\n"
+        "sadalp  v21.4s, v13.8h\n"
+        "smull    v13.8h,  v1.8b,  v7.8b\n"
+        "sadalp  v22.4s, v14.8h\n"
+        "smull    v14.8h,  v2.8b,  v7.8b\n"
+        "sadalp  v23.4s, v15.8h\n"
+        "smull    v15.8h,  v3.8b,  v7.8b\n"
+
+        // Multiply-accumulate second-half, again into the same
+        // 16bit local accumulator registers. This is where we
+        // take advantage of having int8 instead of uint8 and therefore
+        // being able to accumulate two products into int16.
+        "smlal2   v8.8h,  v0.16b,  v6.16b\n"
+        "smlal2   v9.8h,  v1.16b,  v6.16b\n"
+        "smlal2   v10.8h,  v2.16b,  v6.16b\n"
+        "smlal2   v11.8h,  v3.16b,  v6.16b\n"
+        "smlal2   v12.8h,  v0.16b,  v7.16b\n"
+        "smlal2   v13.8h,  v1.16b,  v7.16b\n"
+        "smlal2   v14.8h,  v2.16b,  v7.16b\n"
+        "smlal2   v15.8h,  v3.16b,  v7.16b\n"
+
+        "sadalp  v24.4s, v8.8h\n"
+        "sadalp  v25.4s, v9.8h\n"
+        "sadalp  v26.4s, v10.8h\n"
+        "sadalp  v27.4s, v11.8h\n"
+        "sadalp  v28.4s, v12.8h\n"
+        "sadalp  v29.4s, v13.8h\n"
+        "sadalp  v30.4s, v14.8h\n"
+        "sadalp  v31.4s, v15.8h\n"
+
+        // Reduce 32bit accumulators horizontally.
+        "addp v0.4s, v16.4s, v20.4s\n"
+        "addp v2.4s, v17.4s, v21.4s\n"
+        "addp v4.4s, v18.4s, v22.4s\n"
+        "addp v6.4s, v19.4s, v23.4s\n"
+        "addp v1.4s, v24.4s, v28.4s\n"
+        "addp v3.4s, v25.4s, v29.4s\n"
+        "addp v5.4s, v26.4s, v30.4s\n"
+        "addp v7.4s, v27.4s, v31.4s\n"
+
+        "cmp %[start_depth], #0\n"
+        "bne " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
+        "f\n"
+
+        // Reduce 32bit accumulators horizontally, second pass
+        // (each pass adds pairwise; we need to add 4-wise).
+        "addp v12.4s, v0.4s, v1.4s\n"
+        "addp v13.4s, v2.4s, v3.4s\n"
+        "addp v14.4s, v4.4s, v5.4s\n"
+        "addp v15.4s, v6.4s, v7.4s\n"
+
+        "b " GEMMLOWP_LABEL_STORE "f\n"
+
+        GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
+        ":\n"
+
+        // Reduce 32bit accumulators horizontally, second pass
+        // (each pass adds pairwise; we need to add 4-wise),
+        // and load destination values from memory.
+        "mov x0, %[dst_ptr]\n"
+        "ld1 {v12.16b}, [x0], %[dst_col_stride]\n"
+        "addp v8.4s, v0.4s, v1.4s\n"
+        "ld1 {v13.16b}, [x0], %[dst_col_stride]\n"
+        "addp v9.4s, v2.4s, v3.4s\n"
+        "ld1 {v14.16b}, [x0], %[dst_col_stride]\n"
+        "addp v10.4s, v4.4s, v5.4s\n"
+        "ld1 {v15.16b}, [x0]\n"
+        "addp v11.4s, v6.4s, v7.4s\n"
+
+        // Add horizontally-reduced accumulators into
+        // the values loaded from memory
+        "add v12.4s, v12.4s, v8.4s\n"
+        "add v13.4s, v13.4s, v9.4s\n"
+        "add v14.4s, v14.4s, v10.4s\n"
+        "add v15.4s, v15.4s, v11.4s\n"
+
+        GEMMLOWP_LABEL_STORE
+        ":\n"
+        // Store back into memory
+        "mov x0, %[dst_ptr]\n"
+        "st1 {v12.16b}, [x0], %[dst_col_stride]\n"
+        "st1 {v13.16b}, [x0], %[dst_col_stride]\n"
+        "st1 {v14.16b}, [x0], %[dst_col_stride]\n"
+        "st1 {v15.16b}, [x0]\n"
+        :  // outputs
+        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+        [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth),
+        [dst_col_stride] "+r"(dst_col_stride)
+        :  // inputs
+        [start_depth] "r"(start_depth)
+        :  // clobbers
+        "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+        "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
+        "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
+        "v28", "v29", "v30", "v31");
+#undef GEMMLOWP_LABEL_LOOP
+#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16
+#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
+#undef GEMMLOWP_LABEL_STORE
+  }
+};
+
+
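+// Illustrative note on the int8 trick used by the kernel above. This is a
+// scalar sketch for exposition only; it is not part of the kernel and the
+// function name below is made up:
+//
+//   // With int8 operands whose range is restricted enough (e.g. one side
+//   // kept within [-127, 127]), the sum of two products still fits in an
+//   // int16 accumulator: 2 * 127 * 128 = 32512 < 32767. This is what the
+//   // smull/smlal2 pairs followed by sadalp exploit above.
+//   inline std::int32_t AccumulateTwoProductsWithin16Bits(
+//       std::int8_t a0, std::int8_t b0, std::int8_t a1, std::int8_t b1) {
+//     std::int16_t acc = std::int16_t(a0) * b0;  // like smull
+//     acc += std::int16_t(a1) * b1;              // like smlal2
+//     return acc;                                // widened to 32bit, like sadalp
+//   }
+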
 // Our main GEMM kernel.
 struct NEON_64_Kernel12x8Depth2 : KernelBase {
   typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
@@ -658,13 +1277,81 @@
            std::size_t run_depth) const override {
     ScopedProfilingLabel label("optimized kernel (NEON 12x8)");
 // See comments above for why we need local numerical labels in our asm.
-#define GEMMLOWP_LOOP_NEON_64_KERNEL_12X8_DEPTH2 "1"
-#define GEMMLOWP_STORE_RESULT_NEON_64_KERNEL_12x8_DEPTH2 "2"
+#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
+#define GEMMLOWP_LABEL_BEFORE_LOOP "2"
+#define GEMMLOWP_LABEL_LOOP "3"
+#define GEMMLOWP_LABEL_AFTER_LOOP "4"
 
     assert(dst_row_stride == 1);
     asm volatile(
+        // Load 1 Rhs cell of size 2x8
+        "ld1 {v5.8b}, [%[rhs_ptr]], #8\n"
+        "ld1 {v6.8b}, [%[rhs_ptr]], #8\n"
+
+        // Load 3 Lhs cells of size 4x2 each
+        "ld1 {v2.8b}, [%[lhs_ptr]], #8\n"
+        "ld1 {v3.8b}, [%[lhs_ptr]], #8\n"
+        "ld1 {v4.8b}, [%[lhs_ptr]], #8\n"
+
+        // Multiply dst_col_stride by 4 == sizeof(int32) to use
+        // it as a byte offset below.
+        "lsl %[dst_col_stride], %[dst_col_stride], #2\n"
+
+        "cmp %[start_depth], #0\n"
+        "beq " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+        "f\n"
+
+        // Load accumulators
+        "mov x1, %[dst_ptr]\n"
+        "mov x0, x1\n"
+        "ld1 {v8.16b}, [x0], #16\n"
+        "subs %[run_depth], %[run_depth], #2\n"
+        "ld1 {v16.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v24.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v9.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v17.16b}, [x0], #16\n"
+        "ld1 {v25.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v10.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v18.16b}, [x0], #16\n"
+        "ld1 {v26.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v11.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v19.16b}, [x0], #16\n"
+        "ld1 {v27.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v12.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v20.16b}, [x0], #16\n"
+        "ld1 {v28.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v13.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v21.16b}, [x0], #16\n"
+        "ld1 {v29.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v14.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "ld1 {v22.16b}, [x0], #16\n"
+        "ld1 {v30.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "ld1 {v15.16b}, [x0], #16\n"
+        "ld1 {v23.16b}, [x0], #16\n"
+        "ld1 {v31.16b}, [x0]\n"
+
+        "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
+
+        GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+        ":\n"
+
         // Clear accumulator registers (see layout below)
         "dup v8.4s, wzr\n"
+        "subs %[run_depth], %[run_depth], #2\n"
         "dup v9.4s, wzr\n"
         "dup v10.4s, wzr\n"
         "dup v11.4s, wzr\n"
@@ -689,9 +1376,12 @@
         "dup v30.4s, wzr\n"
         "dup v31.4s, wzr\n"
 
-        /* Main loop */
+        GEMMLOWP_LABEL_BEFORE_LOOP
+        ":\n"
 
-        GEMMLOWP_LOOP_NEON_64_KERNEL_12X8_DEPTH2
+        "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"
+
+        GEMMLOWP_LABEL_LOOP
         ":\n"
 
         // Overview of register layout:
@@ -729,18 +1419,82 @@
         //
         //                            Accumulator
 
-        // Load 1 Rhs cell of size 2x8
-        "ld1 {v0.8b}, [%[rhs_ptr]], #8\n"
-        "ld1 {v1.8b}, [%[rhs_ptr]], #8\n"
+        // Expand Lhs/Rhs cells to 16 bit.
+        "uxtl v0.8h, v5.8b\n"
+        "ld1 {v5.8b}, [%[rhs_ptr]], #8\n"
+        "uxtl v1.8h, v6.8b\n"
+        "ld1 {v6.8b}, [%[rhs_ptr]], #8\n"
+        "uxtl v2.8h, v2.8b\n"
+        "uxtl v3.8h, v3.8b\n"
+        "uxtl v4.8h, v4.8b\n"
 
-        // Load 3 Lhs cells of size 4x2 each
+        // Multiply-accumulate, top third
+        "umlal v8.4s, v2.4h, v0.h[0]\n"
+        "umlal v9.4s, v2.4h, v0.h[1]\n"
+        "umlal v10.4s, v2.4h, v0.h[2]\n"
+        "umlal v11.4s, v2.4h, v0.h[3]\n"
+        "umlal v12.4s, v2.4h, v1.h[0]\n"
+        "umlal v13.4s, v2.4h, v1.h[1]\n"
+        "umlal v14.4s, v2.4h, v1.h[2]\n"
+        "umlal v15.4s, v2.4h, v1.h[3]\n"
+        "umlal2 v8.4s, v2.8h, v0.h[4]\n"
+        "umlal2 v9.4s, v2.8h, v0.h[5]\n"
+        "umlal2 v10.4s, v2.8h, v0.h[6]\n"
+        "umlal2 v11.4s, v2.8h, v0.h[7]\n"
+        "umlal2 v12.4s, v2.8h, v1.h[4]\n"
+        "umlal2 v13.4s, v2.8h, v1.h[5]\n"
+        "umlal2 v14.4s, v2.8h, v1.h[6]\n"
+        "umlal2 v15.4s, v2.8h, v1.h[7]\n"
         "ld1 {v2.8b}, [%[lhs_ptr]], #8\n"
+
+        // Multiply-accumulate, middle third
+        "umlal v16.4s, v3.4h, v0.h[0]\n"
+        "umlal v17.4s, v3.4h, v0.h[1]\n"
+        "umlal v18.4s, v3.4h, v0.h[2]\n"
+        "umlal v19.4s, v3.4h, v0.h[3]\n"
+        "umlal v20.4s, v3.4h, v1.h[0]\n"
+        "umlal v21.4s, v3.4h, v1.h[1]\n"
+        "umlal v22.4s, v3.4h, v1.h[2]\n"
+        "umlal v23.4s, v3.4h, v1.h[3]\n"
+        "umlal2 v16.4s, v3.8h, v0.h[4]\n"
+        "umlal2 v17.4s, v3.8h, v0.h[5]\n"
+        "umlal2 v18.4s, v3.8h, v0.h[6]\n"
+        "umlal2 v19.4s, v3.8h, v0.h[7]\n"
+        "umlal2 v20.4s, v3.8h, v1.h[4]\n"
+        "umlal2 v21.4s, v3.8h, v1.h[5]\n"
+        "umlal2 v22.4s, v3.8h, v1.h[6]\n"
+        "umlal2 v23.4s, v3.8h, v1.h[7]\n"
         "ld1 {v3.8b}, [%[lhs_ptr]], #8\n"
+
+        "subs %[run_depth], %[run_depth], #2\n"
+
+        // Multiply-accumulate, bottom third
+        "umlal v24.4s, v4.4h, v0.h[0]\n"
+        "umlal v25.4s, v4.4h, v0.h[1]\n"
+        "umlal v26.4s, v4.4h, v0.h[2]\n"
+        "umlal v27.4s, v4.4h, v0.h[3]\n"
+        "umlal v28.4s, v4.4h, v1.h[0]\n"
+        "umlal v29.4s, v4.4h, v1.h[1]\n"
+        "umlal v30.4s, v4.4h, v1.h[2]\n"
+        "umlal v31.4s, v4.4h, v1.h[3]\n"
+        "umlal2 v24.4s, v4.8h, v0.h[4]\n"
+        "umlal2 v25.4s, v4.8h, v0.h[5]\n"
+        "umlal2 v26.4s, v4.8h, v0.h[6]\n"
+        "umlal2 v27.4s, v4.8h, v0.h[7]\n"
+        "umlal2 v28.4s, v4.8h, v1.h[4]\n"
+        "umlal2 v29.4s, v4.8h, v1.h[5]\n"
+        "umlal2 v30.4s, v4.8h, v1.h[6]\n"
+        "umlal2 v31.4s, v4.8h, v1.h[7]\n"
         "ld1 {v4.8b}, [%[lhs_ptr]], #8\n"
 
+        "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+        GEMMLOWP_LABEL_AFTER_LOOP
+        ":\n"
+
         // Expand Lhs/Rhs cells to 16 bit.
-        "uxtl v0.8h, v0.8b\n"
-        "uxtl v1.8h, v1.8b\n"
+        "uxtl v0.8h, v5.8b\n"
+        "uxtl v1.8h, v6.8b\n"
         "uxtl v2.8h, v2.8b\n"
         "uxtl v3.8h, v3.8b\n"
         "uxtl v4.8h, v4.8b\n"
@@ -797,167 +1551,52 @@
         "umlal2 v30.4s, v4.8h, v1.h[6]\n"
         "umlal2 v31.4s, v4.8h, v1.h[7]\n"
 
-        // Loop. Decrement loop index (depth) by 2, since we just handled 2
-        // levels of depth (Kernel::kDepth=2).
+        // Store accumulators
+        "mov x1, %[dst_ptr]\n"
+        "mov x0, x1\n"
+        "st1 {v8.16b}, [x0], #16\n"
         "subs %[run_depth], %[run_depth], #2\n"
-        "bne " GEMMLOWP_LOOP_NEON_64_KERNEL_12X8_DEPTH2
-        "b\n"
-
-        /* end of main loop */
-
-        /* Accumulate our local accumulator registers into the destination block
-           */
-
-        // Compute stride between consecutive columns, in bytes
-        "mov x0, #4\n"  // multiply by 4 = sizeof(int32)
-        "mul %[dst_col_stride], %[dst_col_stride], x0\n"
-
-        // If start_depth == 0, then there is no preexisting accumulator
-        // to accumulate, so we can simply store our result.
-        "cmp %[start_depth], #0\n"
-        "beq " GEMMLOWP_STORE_RESULT_NEON_64_KERNEL_12x8_DEPTH2
-        "f\n"
-
-        "mov x0, %[dst_ptr]\n"
-
-        // Load a column
-        "mov x1, x0\n"
-        "ld1 {v0.4s}, [x1], #16\n"
-        "ld1 {v1.4s}, [x1], #16\n"
-        "ld1 {v2.4s}, [x1], #16\n"
-        // Accumulate a column
-        "add v8.4s, v8.4s, v0.4s\n"
-        "add v16.4s, v16.4s, v1.4s\n"
-        "add v24.4s, v24.4s, v2.4s\n"
-
-        "add x0, x0, %[dst_col_stride]\n"
-        // Load a column
-        "mov x1, x0\n"
-        "ld1 {v0.4s}, [x1], #16\n"
-        "ld1 {v1.4s}, [x1], #16\n"
-        "ld1 {v2.4s}, [x1], #16\n"
-        // Accumulate a column
-        "add v9.4s, v9.4s, v0.4s\n"
-        "add v17.4s, v17.4s, v1.4s\n"
-        "add v25.4s, v25.4s, v2.4s\n"
-
-        "add x0, x0, %[dst_col_stride]\n"
-        // Load a column
-        "mov x1, x0\n"
-        "ld1 {v0.4s}, [x1], #16\n"
-        "ld1 {v1.4s}, [x1], #16\n"
-        "ld1 {v2.4s}, [x1], #16\n"
-        // Accumulate a column
-        "add v10.4s, v10.4s, v0.4s\n"
-        "add v18.4s, v18.4s, v1.4s\n"
-        "add v26.4s, v26.4s, v2.4s\n"
-
-        "add x0, x0, %[dst_col_stride]\n"
-        // Load a column
-        "mov x1, x0\n"
-        "ld1 {v0.4s}, [x1], #16\n"
-        "ld1 {v1.4s}, [x1], #16\n"
-        "ld1 {v2.4s}, [x1], #16\n"
-        // Accumulate a column
-        "add v11.4s, v11.4s, v0.4s\n"
-        "add v19.4s, v19.4s, v1.4s\n"
-        "add v27.4s, v27.4s, v2.4s\n"
-
-        "add x0, x0, %[dst_col_stride]\n"
-        // Load a column
-        "mov x1, x0\n"
-        "ld1 {v0.4s}, [x1], #16\n"
-        "ld1 {v1.4s}, [x1], #16\n"
-        "ld1 {v2.4s}, [x1], #16\n"
-        // Accumulate a column
-        "add v12.4s, v12.4s, v0.4s\n"
-        "add v20.4s, v20.4s, v1.4s\n"
-        "add v28.4s, v28.4s, v2.4s\n"
-
-        "add x0, x0, %[dst_col_stride]\n"
-        // Load a column
-        "mov x1, x0\n"
-        "ld1 {v0.4s}, [x1], #16\n"
-        "ld1 {v1.4s}, [x1], #16\n"
-        "ld1 {v2.4s}, [x1], #16\n"
-        // Accumulate a column
-        "add v13.4s, v13.4s, v0.4s\n"
-        "add v21.4s, v21.4s, v1.4s\n"
-        "add v29.4s, v29.4s, v2.4s\n"
-
-        "add x0, x0, %[dst_col_stride]\n"
-        // Load a column
-        "mov x1, x0\n"
-        "ld1 {v0.4s}, [x1], #16\n"
-        "ld1 {v1.4s}, [x1], #16\n"
-        "ld1 {v2.4s}, [x1], #16\n"
-        // Accumulate a column
-        "add v14.4s, v14.4s, v0.4s\n"
-        "add v22.4s, v22.4s, v1.4s\n"
-        "add v30.4s, v30.4s, v2.4s\n"
-
-        "add x0, x0, %[dst_col_stride]\n"
-        // Load a column
-        "mov x1, x0\n"
-        "ld1 {v0.4s}, [x1], #16\n"
-        "ld1 {v1.4s}, [x1], #16\n"
-        "ld1 {v2.4s}, [x1], #16\n"
-        // Accumulate a column
-        "add v15.4s, v15.4s, v0.4s\n"
-        "add v23.4s, v23.4s, v1.4s\n"
-        "add v31.4s, v31.4s, v2.4s\n"
-
-        GEMMLOWP_STORE_RESULT_NEON_64_KERNEL_12x8_DEPTH2
-        ":\n"
-
-        "mov x0, %[dst_ptr]\n"
-        // Store a column
-        "mov x1, x0\n"
-        "st1 {v8.4s}, [x1], #16\n"
-        "st1 {v16.4s}, [x1], #16\n"
-        "st1 {v24.4s}, [x1], #16\n"
-        // Store a column
-        "add x0, x0, %[dst_col_stride]\n"
-        "mov x1, x0\n"
-        "st1 {v9.4s}, [x1], #16\n"
-        "st1 {v17.4s}, [x1], #16\n"
-        "st1 {v25.4s}, [x1], #16\n"
-        // Store a column
-        "add x0, x0, %[dst_col_stride]\n"
-        "mov x1, x0\n"
-        "st1 {v10.4s}, [x1], #16\n"
-        "st1 {v18.4s}, [x1], #16\n"
-        "st1 {v26.4s}, [x1], #16\n"
-        // Store a column
-        "add x0, x0, %[dst_col_stride]\n"
-        "mov x1, x0\n"
-        "st1 {v11.4s}, [x1], #16\n"
-        "st1 {v19.4s}, [x1], #16\n"
-        "st1 {v27.4s}, [x1], #16\n"
-        // Store a column
-        "add x0, x0, %[dst_col_stride]\n"
-        "mov x1, x0\n"
-        "st1 {v12.4s}, [x1], #16\n"
-        "st1 {v20.4s}, [x1], #16\n"
-        "st1 {v28.4s}, [x1], #16\n"
-        // Store a column
-        "add x0, x0, %[dst_col_stride]\n"
-        "mov x1, x0\n"
-        "st1 {v13.4s}, [x1], #16\n"
-        "st1 {v21.4s}, [x1], #16\n"
-        "st1 {v29.4s}, [x1], #16\n"
-        // Store a column
-        "add x0, x0, %[dst_col_stride]\n"
-        "mov x1, x0\n"
-        "st1 {v14.4s}, [x1], #16\n"
-        "st1 {v22.4s}, [x1], #16\n"
-        "st1 {v30.4s}, [x1], #16\n"
-        // Store a column
-        "add x0, x0, %[dst_col_stride]\n"
-        "mov x1, x0\n"
-        "st1 {v15.4s}, [x1], #16\n"
-        "st1 {v23.4s}, [x1], #16\n"
-        "st1 {v31.4s}, [x1], #16\n"
+        "st1 {v16.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v24.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v9.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v17.16b}, [x0], #16\n"
+        "st1 {v25.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v10.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v18.16b}, [x0], #16\n"
+        "st1 {v26.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v11.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v19.16b}, [x0], #16\n"
+        "st1 {v27.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v12.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v20.16b}, [x0], #16\n"
+        "st1 {v28.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v13.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v21.16b}, [x0], #16\n"
+        "st1 {v29.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v14.16b}, [x0], #16\n"
+        "add x1, x1, %[dst_col_stride]\n"
+        "st1 {v22.16b}, [x0], #16\n"
+        "st1 {v30.16b}, [x0]\n"
+        "mov x0, x1\n"
+        "st1 {v15.16b}, [x0], #16\n"
+        "st1 {v23.16b}, [x0], #16\n"
+        "st1 {v31.16b}, [x0]\n"
+#undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+#undef GEMMLOWP_LABEL_BEFORE_LOOP
+#undef GEMMLOWP_LABEL_LOOP
+#undef GEMMLOWP_LABEL_AFTER_LOOP
         :  // outputs
         [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
         [dst_ptr] "+r"(dst_ptr),
@@ -970,78 +1609,11 @@
         "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
         "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
         "v27", "v28", "v29", "v30", "v31");
-#undef GEMMLOWP_LOOP_NEON_64_KERNEL_12X8_DEPTH2
-#undef GEMMLOWP_STORE_RESULT_NEON_64_KERNEL_12x8_DEPTH2
   }
 };
 
 #endif  // GEMMLOWP_NEON_64
 
-// Our main GEMV kernel.
-// Because our GEMV performance is low and not dominated by the kernel
-// at the moment, it's not worth optimizing too hard yet.
-// Using intrinsics allows us to write one implementation for both 32bit and
-// 64bit ARM, and should also perform OK here because the register pressure
-// is not so high in this GEMV kernel.
-// When/if we get serious about GEMV performance, we will want to
-// implement it to bypass packing altogether, and use source data in-place
-// with different GEMV kernels for row-major and column-major LHS.
-template <int Cells>
-struct NEONKernel4Nx1Depth2 : KernelBase {
-  typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, Cells>,
-                       KernelSideFormat<CellFormat<1, 2>, 1> >
-      Format;
-
-  const char* Name() const override { return "NEON intrinsics, 4Nx1, depth 2"; }
-
-  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
-           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
-           const std::uint8_t* rhs_ptr, std::size_t start_depth,
-           std::size_t run_depth) const override {
-    ScopedProfilingLabel label("optimized kernel (NEON 4Nx1)");
-
-    assert(dst_row_stride == 1);
-
-    // Clear accumulators
-    uint32x4_t acc[Cells];
-    for (int cell = 0; cell < Cells; cell++) {
-      acc[cell] = vdupq_n_u32(0);
-    }
-    // Main loop
-    for (std::size_t d = 0; d < run_depth; d += 2) {
-      // Load LHS cells
-      uint16x8_t lhs[Cells];
-      for (int cell = 0; cell < Cells; cell++) {
-        lhs[cell] = vmovl_u8(vld1_u8(lhs_ptr));
-        lhs_ptr += 8;
-      }
-      // Load RHS cell
-      uint16_t rhs0 = rhs_ptr[0];
-      uint16_t rhs1 = rhs_ptr[1];
-      rhs_ptr += 2;
-      // Multiply-accumulate, level of depth 0
-      for (int cell = 0; cell < Cells; cell++) {
-        acc[cell] = vmlal_n_u16(acc[cell], vget_low_u16(lhs[cell]), rhs0);
-      }
-      // Multiply-accumulate, level of depth 1
-      for (int cell = 0; cell < Cells; cell++) {
-        acc[cell] = vmlal_n_u16(acc[cell], vget_high_u16(lhs[cell]), rhs1);
-      }
-    }
-    // If start_depth is nonzero, accumulate with the existing accumulator
-    if (start_depth) {
-      for (int cell = 0; cell < Cells; cell++) {
-        acc[cell] = vaddq_u32(
-            acc[cell], vreinterpretq_u32_s32(vld1q_s32(dst_ptr + 4 * cell)));
-      }
-    }
-    // Store the accumulators
-    for (int cell = 0; cell < Cells; cell++) {
-      vst1q_s32(dst_ptr + 4 * cell, vreinterpretq_s32_u32(acc[cell]));
-    }
-  }
-};
-
 }  // namespace gemmlowp
 
 #endif  // GEMMLOWP_INTERNAL_KERNEL_NEON_H_
diff --git a/internal/kernel_reference.h b/internal/kernel_reference.h
index 020b479..3458c6a 100644
--- a/internal/kernel_reference.h
+++ b/internal/kernel_reference.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -59,15 +59,13 @@
       // The next two loops are over cells of the Lhs (stacked vertically),
       // and over cells of the Rhs (stacked horizontally).
       for (int rc = 0; rc < Format::Lhs::kCells; rc++) {
-        const std::uint8_t* lhs_cell_ptr = lhs_ptr +
-                                           (dc * Format::Lhs::kCells + rc) *
-                                               Format::Lhs::Cell::kWidth *
-                                               Format::kDepth;
+        const std::uint8_t* lhs_cell_ptr =
+            lhs_ptr + (dc * Format::Lhs::kCells + rc) *
+                          Format::Lhs::Cell::kWidth * Format::kDepth;
         for (int cc = 0; cc < Format::Rhs::kCells; cc++) {
-          const std::uint8_t* rhs_cell_ptr = rhs_ptr +
-                                             (dc * Format::Rhs::kCells + cc) *
-                                                 Format::Rhs::Cell::kWidth *
-                                                 Format::kDepth;
+          const std::uint8_t* rhs_cell_ptr =
+              rhs_ptr + (dc * Format::Rhs::kCells + cc) *
+                            Format::Rhs::Cell::kWidth * Format::kDepth;
 
           // Now we are inside one cell of the Lhs and inside one cell
           // of the Rhs, so the remaining inner loops are just
diff --git a/internal/kernel_sse.h b/internal/kernel_sse.h
new file mode 100644
index 0000000..b879fd7
--- /dev/null
+++ b/internal/kernel_sse.h
@@ -0,0 +1,517 @@
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// kernel_sse.h: a collection of Intel SSE optimized kernels.
+// See kernel_default.h for which one(s) are actually used by default.
+// Others are mere experiments; they are still covered by tests
+// in case they might be useful some day.
+//
+
+#ifndef GEMMLOWP_INTERNAL_KERNEL_SSE_H_
+#define GEMMLOWP_INTERNAL_KERNEL_SSE_H_
+
+#include "kernel.h"
+
+#include <string.h>
+#include <cassert>
+
+namespace gemmlowp {
+
+#ifdef GEMMLOWP_SSE4_32
+struct SSE4_32_Kernel4x4Depth2 : KernelBase {
+  typedef KernelFormat<
+      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1>,
+      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1> >
+      Format;
+
+  const char* Name() const override { return "SSE, 4x4, depth 2"; }
+
+  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
+           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
+           const std::uint8_t* rhs_ptr, std::size_t start_depth,
+           std::size_t run_depth) const override {
+    ScopedProfilingLabel label("optimized kernel");
+    assert(dst_row_stride == 1);
+    std::int32_t run_depth_cells = run_depth / Format::kDepth;
+    /* Main loop */
+
+    // A 2x4 cell of Rhs is stored in 16bit in xmm1.
+    // A 4x2 block of Lhs is stored in 16bit in xmm0.
+    // A 4x4 block of accumulators is stored in 32bit in xmm4--xmm7.
+    //
+    //                   +-------+-------+-------+-------+
+    //                   |xmm1[0]|xmm1[2]|xmm1[4]|xmm1[6]|
+    //              Rhs  +-------+-------+-------+-------+
+    //                   |xmm1[1]|xmm1[3]|xmm1[5]|xmm1[7]|
+    //                   +-------+-------+-------+-------+
+    //
+    //                   |       |       |       |       |
+    //
+    //    Lhs            |       |       |       |       |
+    //
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
+    //  |xmm0 | (Iter1)  | xmm4  | xmm5  | xmm6  | xmm7  |
+    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
+    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //
+    //                              Accumulator
+
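+    // In outline: pmovzxbw zero-extends 8 uint8 values to 16 bit, pshufd
+    // broadcasts one Rhs column's (depth 0, depth 1) word pair to all four
+    // dwords, and pmaddwd computes, per 32-bit lane,
+    // a[2i]*b[2i] + a[2i+1]*b[2i+1], i.e. the depth-2 dot product of the four
+    // Lhs rows against that Rhs column, which paddd then accumulates into one
+    // of xmm4..xmm7.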
+    asm volatile(
+
+        // set accumulators to zero.
+        "pxor %%xmm4  , %%xmm4 \n\t"
+        "pxor %%xmm5  , %%xmm5 \n\t"
+        "pxor %%xmm6  , %%xmm6 \n\t"
+        "pxor %%xmm7  , %%xmm7 \n\t"
+
+        "movl  %[run_depth_cells], %%eax\n\t"
+        "subl $2, %%eax\n\t"
+        "js outerLoop1%=\n\t"
+
+        // Loop for K unrolled by 4
+        "outerLoop2%=:\n\t"
+
+        // K = 1,2
+        // RHS cell to xmm1
+        "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t"
+
+        // LHS cell
+        "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t"
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm4           \n\t"
+        "paddd %%xmm3, %%xmm5           \n\t"
+
+        "prefetcht0 0x80(%[lhs_ptr]) \n\t"
+
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+
+        "prefetcht0 0x80(%[rhs_ptr]) \n\t"
+
+        // K = 3,4
+        // RHS cell to xmm1
+        "pmovzxbw 0x08(%[rhs_ptr]), %%xmm1\n\t"
+
+        "paddd %%xmm2, %%xmm6           \n\t"
+        "paddd %%xmm3, %%xmm7           \n\t"
+
+        // LHS cell
+        "pmovzxbw 0x08(%[lhs_ptr]), %%xmm0\n\t"
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm4           \n\t"
+        "paddd %%xmm3, %%xmm5           \n\t"
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+
+        "addl $0x10, %[lhs_ptr]         \n\t"
+        "addl $0x10, %[rhs_ptr]         \n\t"
+
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm3, %%xmm7           \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "paddd %%xmm2, %%xmm6           \n\t"
+
+        "subl $2, %[run_depth_cells]\n\t"
+        "ja outerLoop2%=\n\t"
+
+        "movl %[run_depth_cells], %%eax\n\t"
+        "decl %%eax\n\t"
+        "js finish%=\n\t"
+
+        // Loop for K unrolled by 2
+        "outerLoop1%=:\n\t"
+
+        // RHS cell to xmm1
+        "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t"
+
+        // LHS cell
+        "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t"
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "paddd %%xmm2, %%xmm4           \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm3, %%xmm5           \n\t"
+
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "paddd %%xmm2, %%xmm6           \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm3, %%xmm7           \n\t"
+
+        "addl $0x08, %[lhs_ptr]\n\t"
+        "addl $0x08, %[rhs_ptr]\n\t"
+
+        "decl %[run_depth_cells]\n\t"
+        "jnz outerLoop1%=\n\t"
+
+        "finish%=:\n\t"
+
+        "movl  %[dst_col_stride], %%eax\n\t"
+        "shll $2, %%eax\n\t"
+
+        "movl  %[start_depth], %%ecx\n\t"
+        "test %%ecx, %%ecx\n\t"
+        "jz storeDst%=\n\t"
+
+        "leal (%%eax,%%eax,0x2), %%ecx\n\t"
+        "paddd 0x00(%[dst_ptr])           , %%xmm4 \n\t"
+        "paddd 0x00(%[dst_ptr], %%eax, 1) , %%xmm5 \n\t"
+        "paddd 0x00(%[dst_ptr], %%eax, 2) , %%xmm6 \n\t"
+        "paddd 0x00(%[dst_ptr], %%ecx, 1) , %%xmm7 \n\t"
+
+        "storeDst%=:\n\t"
+
+        "leal (%%eax,%%eax,0x2), %%ecx\n\t"
+        "movdqu %%xmm4  , 0x00(%[dst_ptr])          \n\t"
+        "movdqu %%xmm5  , 0x00(%[dst_ptr], %%eax, 1)\n\t"
+        "movdqu %%xmm6  , 0x00(%[dst_ptr], %%eax, 2)\n\t"
+        "movdqu %%xmm7  , 0x00(%[dst_ptr], %%ecx, 1)\n\t"
+
+        :  // outputs
+        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+        [dst_ptr] "+r"(dst_ptr)
+        :  // inputs
+        [start_depth] "g"(start_depth), [dst_col_stride] "g"(dst_col_stride),
+        [run_depth_cells] "g"(run_depth_cells)
+        :  // clobbers
+        "cc", "memory", "%xmm0", "%xmm1", "%xmm3", "%xmm2", "%xmm4", "%xmm5",
+        "%xmm6", "%xmm7", "%eax", "%ecx");
+  }
+};
+#endif
+#ifdef GEMMLOWP_SSE4_64
+struct SSE4_64_Kernel12x4Depth2 : KernelBase {
+  typedef KernelFormat<
+      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>,
+      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1> >
+      Format;
+
+  const char* Name() const override { return "SSE, 12x4, depth 2"; }
+
+  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
+           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
+           const std::uint8_t* rhs_ptr, std::size_t start_depth,
+           std::size_t run_depth) const override {
+    ScopedProfilingLabel label("optimized kernel");
+    assert(dst_row_stride == 1);
+    const std::int64_t run_depth_cells = run_depth / Format::kDepth;
+    const std::int64_t dst_col_stride_q = dst_col_stride;
+
+    /* Main loop */
+
+    // A 2x4 cell of Rhs is stored in 16bit in xmm1.
+    // A 12x2 block of three 4x2 Lhs cells is stored in 16bit in xmm0,
+    // replaced every iteration.
+    // A 12x4 block of accumulators is stored in 32bit in xmm4--xmm15.
+    //
+    //                   +-------+-------+-------+-------+
+    //                   |xmm1[0]|xmm1[2]|xmm1[4]|xmm1[6]|
+    //              Rhs  +-------+-------+-------+-------+
+    //                   |xmm1[1]|xmm1[3]|xmm1[5]|xmm1[7]|
+    //                   +-------+-------+-------+-------+
+    //
+    //                   |       |       |       |       |
+    //
+    //    Lhs            |       |       |       |       |
+    //
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
+    //  |xmm0 | (Iter1)  | xmm4  | xmm5  | xmm6  | xmm7  |
+    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
+    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //  |xmm0 |          | xmm8  | xmm9  | xmm10 | xmm11 |
+    //  |xmm0 | (Iter2)  | xmm8  | xmm9  | xmm10 | xmm11 |
+    //  |xmm0 |          | xmm8  | xmm9  | xmm10 | xmm11 |
+    //  |xmm0 |          | xmm8  | xmm9  | xmm10 | xmm11 |
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //  |xmm0 |          | xmm12 | xmm13 | xmm14 | xmm15 |
+    //  |xmm0 | (Iter3)  | xmm12 | xmm13 | xmm14 | xmm15 |
+    //  |xmm0 |          | xmm12 | xmm13 | xmm14 | xmm15 |
+    //  |xmm0 |          | xmm12 | xmm13 | xmm14 | xmm15 |
+    //  +--+--+ - - - -  +-------+-------+-------+-------+
+    //
+    //                              Accumulator
+
+    asm volatile(
+
+        // Set registers for destination
+        "movq  %[dst_col_stride_q], %%r12\n\t"
+        "shlq $2, %%r12\n\t"
+        "leaq (%%r12,%%r12,0x2), %%r13\n\t"
+
+        // Set accumulators to zero.
+        "pxor %%xmm4  , %%xmm4 \n\t"
+        "pxor %%xmm5  , %%xmm5 \n\t"
+        "pxor %%xmm6  , %%xmm6 \n\t"
+        "pxor %%xmm7  , %%xmm7 \n\t"
+        "pxor %%xmm8  , %%xmm8 \n\t"
+        "pxor %%xmm9  , %%xmm9 \n\t"
+        "pxor %%xmm10 , %%xmm10\n\t"
+        "pxor %%xmm11 , %%xmm11\n\t"
+        "pxor %%xmm12 , %%xmm12\n\t"
+        "pxor %%xmm13 , %%xmm13\n\t"
+        "pxor %%xmm14 , %%xmm14\n\t"
+        "pxor %%xmm15 , %%xmm15\n\t"
+
+        "movq  %[run_depth_cells], %%r14\n\t"
+        "subq $2, %%r14\n\t"
+        "js outerLoop1%=\n\t"
+
+        // Loop for K unrolled by 4
+        "outerLoop2%=:\n\t"
+
+        // K = 1,2
+        // RHS cell to xmm1
+
+        "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t"
+
+        // LHS cell
+        "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t"
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm4           \n\t"
+        "paddd %%xmm3, %%xmm5           \n\t"
+
+        "prefetcht0 0x80(%[lhs_ptr]) \n\t"
+
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+
+        // next LHS cell
+        "pmovzxbw 0x08(%[lhs_ptr]), %%xmm0\n\t"
+
+        "paddd %%xmm2, %%xmm6           \n\t"
+        "paddd %%xmm3, %%xmm7           \n\t"
+
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm8           \n\t"
+        "paddd %%xmm3, %%xmm9           \n\t"
+
+        "prefetcht0 0x80(%[rhs_ptr]) \n\t"
+
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm10          \n\t"
+        "paddd %%xmm3, %%xmm11          \n\t"
+
+        // next LHS cell
+        "pmovzxbw 0x10(%[lhs_ptr]), %%xmm0\n\t"
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm12          \n\t"
+        "paddd %%xmm3, %%xmm13          \n\t"
+
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm14          \n\t"
+        "paddd %%xmm3, %%xmm15          \n\t"
+
+        // K = 3,4
+        // RHS cell to xmm1
+        "pmovzxbw 0x08(%[rhs_ptr]), %%xmm1\n\t"
+
+        // LHS cell
+        "pmovzxbw 0x18(%[lhs_ptr]), %%xmm0\n\t"
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm4           \n\t"
+        "paddd %%xmm3, %%xmm5           \n\t"
+
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm6           \n\t"
+        "paddd %%xmm3, %%xmm7           \n\t"
+
+        // next LHS cell
+        "pmovzxbw 0x20(%[lhs_ptr]), %%xmm0\n\t"
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm8           \n\t"
+        "paddd %%xmm3, %%xmm9           \n\t"
+
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm10          \n\t"
+        "paddd %%xmm3, %%xmm11          \n\t"
+
+        // next LHS cell
+        "pmovzxbw 0x28(%[lhs_ptr]), %%xmm0\n\t"
+
+        "addq $0x30, %[lhs_ptr]         \n\t"
+        "addq $0x10, %[rhs_ptr]         \n\t"
+
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm12          \n\t"
+        "paddd %%xmm3, %%xmm13          \n\t"
+
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm14          \n\t"
+        "paddd %%xmm3, %%xmm15          \n\t"
+
+        "subq $2, %[run_depth_cells]\n\t"
+        "ja outerLoop2%=\n\t"
+
+        "movq %[run_depth_cells], %%r14\n\t"
+        "decq %%r14\n\t"
+        "js finish%=\n\t"
+
+        // Loop for K unrolled by 2
+        "outerLoop1%=:\n\t"
+
+        // RHS cell to xmm1
+        "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t"
+
+        // LHS cell
+        "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t"
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm4           \n\t"
+        "paddd %%xmm3, %%xmm5           \n\t"
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm6           \n\t"
+        "paddd %%xmm3, %%xmm7           \n\t"
+
+        // next LHS cell
+        "pmovzxbw 0x08(%[lhs_ptr]), %%xmm0\n\t"
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm8           \n\t"
+        "paddd %%xmm3, %%xmm9           \n\t"
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm10          \n\t"
+        "paddd %%xmm3, %%xmm11          \n\t"
+
+        // next LHS cell
+        "pmovzxbw 0x10(%[lhs_ptr]), %%xmm0\n\t"
+
+        "addq $0x18, %[lhs_ptr]         \n\t"
+        "addq $0x08, %[rhs_ptr]         \n\t"
+
+        "pshufd $0x00,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0x55,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm12          \n\t"
+        "paddd %%xmm3, %%xmm13          \n\t"
+        "pshufd $0xaa,%%xmm1,%%xmm2     \n\t"
+        "pshufd $0xff,%%xmm1,%%xmm3     \n\t"
+        "pmaddwd %%xmm0, %%xmm2         \n\t"
+        "pmaddwd %%xmm0, %%xmm3         \n\t"
+        "paddd %%xmm2, %%xmm14          \n\t"
+        "paddd %%xmm3, %%xmm15          \n\t"
+
+        "decq %[run_depth_cells]\n\t"
+        "jnz outerLoop1%=\n\t"
+
+        "finish%=:\n\t"
+
+        "test %[start_depth], %[start_depth]\n\t"
+        "jz storeDst%=\n\t"
+
+        "paddd 0x00(%[dst_ptr])           , %%xmm4 \n\t"
+        "paddd 0x10(%[dst_ptr])           , %%xmm8 \n\t"
+        "paddd 0x20(%[dst_ptr])           , %%xmm12\n\t"
+        "paddd 0x00(%[dst_ptr], %%r12, 1) , %%xmm5 \n\t"
+        "paddd 0x10(%[dst_ptr], %%r12, 1) , %%xmm9 \n\t"
+        "paddd 0x20(%[dst_ptr], %%r12, 1) , %%xmm13\n\t"
+        "paddd 0x00(%[dst_ptr], %%r12, 2) , %%xmm6 \n\t"
+        "paddd 0x10(%[dst_ptr], %%r12, 2) , %%xmm10\n\t"
+        "paddd 0x20(%[dst_ptr], %%r12, 2) , %%xmm14\n\t"
+        "paddd 0x00(%[dst_ptr], %%r13, 1) , %%xmm7 \n\t"
+        "paddd 0x10(%[dst_ptr], %%r13, 1) , %%xmm11\n\t"
+        "paddd 0x20(%[dst_ptr], %%r13, 1) , %%xmm15\n\t"
+
+        "storeDst%=:\n\t"
+
+        "movdqu %%xmm4  , 0x00(%[dst_ptr])          \n\t"
+        "movdqu %%xmm8  , 0x10(%[dst_ptr])          \n\t"
+        "movdqu %%xmm12 , 0x20(%[dst_ptr])          \n\t"
+        "movdqu %%xmm5  , 0x00(%[dst_ptr], %%r12, 1)\n\t"
+        "movdqu %%xmm9  , 0x10(%[dst_ptr], %%r12, 1)\n\t"
+        "movdqu %%xmm13 , 0x20(%[dst_ptr], %%r12, 1)\n\t"
+        "movdqu %%xmm6  , 0x00(%[dst_ptr], %%r12, 2)\n\t"
+        "movdqu %%xmm10 , 0x10(%[dst_ptr], %%r12, 2)\n\t"
+        "movdqu %%xmm14 , 0x20(%[dst_ptr], %%r12, 2)\n\t"
+        "movdqu %%xmm7  , 0x00(%[dst_ptr], %%r13, 1)\n\t"
+        "movdqu %%xmm11 , 0x10(%[dst_ptr], %%r13, 1)\n\t"
+        "movdqu %%xmm15 , 0x20(%[dst_ptr], %%r13, 1)\n\t"
+
+        :  // outputs
+        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+        [dst_ptr] "+r"(dst_ptr)
+        :  // inputs
+        [start_depth] "r"(start_depth),
+        [dst_col_stride_q] "r"(dst_col_stride_q),
+        [run_depth_cells] "r"(run_depth_cells)
+        :  // clobbers
+        "cc", "memory", "%xmm0", "%xmm1", "%xmm3", "%xmm2", "%xmm4", "%xmm5",
+        "%xmm6", "%xmm7", "%xmm8", "%xmm9", "%xmm10", "%r12", "%r13", "%r14",
+        "%xmm11", "%xmm12", "%xmm13", "%xmm14", "%xmm15");
+  }
+};
+#endif
+
+}  // namespace gemmlowp
+
+#endif  // GEMMLOWP_INTERNAL_KERNEL_SSE_H_
diff --git a/internal/multi_thread_gemm.h b/internal/multi_thread_gemm.h
index 0aacddb..0234b26 100644
--- a/internal/multi_thread_gemm.h
+++ b/internal/multi_thread_gemm.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -27,9 +27,15 @@
 
 namespace gemmlowp {
 
-#ifdef GEMMLOWP_ALLOW_INLINE_ASM
-// Where inline asm is allowed, we use some busy-waiting,
-// preferably implemented using NOP instructions.
+// On X86 and ARM platforms we enable a busy-wait spinlock before waiting on a
+// pthread condition variable. In order to implement that correctly we need
+// to insert explicit memory load/store barriers.
+
+#if defined(GEMMLOWP_ALLOW_INLINE_ASM) && !defined(GEMMLOWP_NO_BUSYWAIT) && \
+    (defined(GEMMLOWP_ARM) || defined(GEMMLOWP_X86))
+
+#define GEMMLOWP_USE_BUSYWAIT
+
 const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;
 
 #define GEMMLOWP_NOP "nop\n"
@@ -38,11 +44,10 @@
 #define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
 #define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
 #define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)
-#define GEMMLOWP_NOP256 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP64)
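+// Note: despite its historical name, Do256NOPs below now issues and reports
+// 64 NOPs per call.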
 
 inline int Do256NOPs() {
-  asm volatile(GEMMLOWP_NOP256);
-  return 256;
+  asm volatile(GEMMLOWP_NOP64);
+  return 64;
 }
 
 #undef GEMMLOWP_STRING_CONCAT_4
@@ -52,20 +57,6 @@
 #undef GEMMLOWP_NOP4
 #undef GEMMLOWP_NOP
 
-#else  // not GEMMLOWP_ALLOW_INLINE_ASM
-
-// It is nontrivial to implement a good busy-waiting without
-// using asm; NOP instructions have the least side effects
-// and the lowest power usage; and since the whole busy-waiting
-// story is an optimization, it's not very interesting anyway
-// in places where we're slow anyway due to not being able to
-// use our inline asm kernels.
-
-const int kMaxBusyWaitNOPs = 0;
-inline int Do256NOPs() { return 0; }
-
-#endif  // not GEMMLOWP_ALLOW_INLINE_ASM
-
 inline void WriteBarrier() {
 #ifdef GEMMLOWP_ARM_32
   MemoryBarrier();
@@ -73,8 +64,6 @@
   asm volatile("dmb ishst" ::: "memory");
 #elif defined(GEMMLOWP_X86)
   asm volatile("sfence" ::: "memory");
-#elif defined(__mips__)
-  MemoryBarrier();
 #else
 #error "Unsupported architecture for WriteBarrier."
 #endif
@@ -87,13 +76,13 @@
   asm volatile("dmb ishld" ::: "memory");
 #elif defined(GEMMLOWP_X86)
   asm volatile("lfence" ::: "memory");
-#elif defined(__mips__)
-  MemoryBarrier();
 #else
 #error "Unsupported architecture for ReadBarrier."
 #endif
 }
 
+#endif
+
 // Waits until *var != initial_value.
 //
 // Returns the new value of *var. The guarantee here is that
@@ -119,23 +108,31 @@
 template <typename T>
 T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
                         pthread_mutex_t* mutex) {
-  int nops = 0;
-  // First, trivial case where the variable already changed value.
-  T new_value = *var;
-  if (new_value != initial_value) {
-    return new_value;
-  }
-  // Then try busy-waiting.
-  while (nops < kMaxBusyWaitNOPs) {
-    nops += Do256NOPs();
-    new_value = *var;
+#ifdef GEMMLOWP_USE_BUSYWAIT
+  // If we are on a platform that supports it, spin for some time.
+  {
+    int nops = 0;
+    // First, trivial case where the variable already changed value.
+    T new_value = *var;
     if (new_value != initial_value) {
+      ReadBarrier();
       return new_value;
     }
+    // Then try busy-waiting.
+    while (nops < kMaxBusyWaitNOPs) {
+      nops += Do256NOPs();
+      new_value = *var;
+      if (new_value != initial_value) {
+        ReadBarrier();
+        return new_value;
+      }
+    }
   }
+#endif
+
   // Finally, do real passive waiting.
   pthread_mutex_lock(mutex);
-  new_value = *var;
+  T new_value = *var;
   if (new_value == initial_value) {
     pthread_cond_wait(cond, mutex);
     new_value = *var;
@@ -174,6 +171,9 @@
     pthread_mutex_lock(&mutex_);
     assert(count_ > 0);
     count_--;
+#ifdef GEMMLOWP_USE_BUSYWAIT
+    WriteBarrier();
+#endif
     if (count_ == 0) {
       pthread_cond_signal(&cond_);
     }
@@ -206,7 +206,7 @@
 struct Task {
   Task() : local_allocator(nullptr) {}
   virtual ~Task() {}
-  virtual void Run() const = 0;
+  virtual void Run() = 0;
   Allocator* local_allocator;
 };
 
@@ -283,10 +283,8 @@
       switch (state_to_act_upon) {
         case State::HasWork:
           // Got work to do! So do it, and then revert to 'Ready' state.
-          ReadBarrier();
           assert(task_);
           task_->Run();
-          delete task_;
           task_ = nullptr;
           ChangeState(State::Ready);
           break;
@@ -309,7 +307,9 @@
     assert(!task_);
     task->local_allocator = &local_allocator_;
     task_ = task;
+#ifdef GEMMLOWP_USE_BUSYWAIT
     WriteBarrier();
+#endif
     assert(state_ == State::Ready);
     ChangeState(State::HasWork);
   }
@@ -319,7 +319,7 @@
   pthread_t thread_;
 
   // The task to be worked on.
-  const Task* task_;
+  Task* task_;
 
   // The condition variable and mutex guarding state changes.
   pthread_cond_t state_cond_;
@@ -341,6 +341,11 @@
 // specific parallelization pattern that we use here:
 // a fixed number of workers can be given work, and one then
 // waits for all of them to finish.
+//
+// See MultiThreadGemmContextBase for how other WorkersPool implementations can
+// be used. Note that in those implementations, StartWorker is free to
+// ignore the <index> value; that is, the caller of WorkersPool does not rely
+// on <index> to order tasks.
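+//
+// Illustrative usage (a sketch only; SomeGemmTask and task_count below are
+// made-up names, not part of this file):
+//
+//   std::vector<Task*> tasks;
+//   for (int i = 0; i < task_count; i++) {
+//     tasks.push_back(new SomeGemmTask(...));
+//   }
+//   workers_pool->Execute(tasks);
+//   // Execute runs the last task on the calling thread, the others on worker
+//   // threads, waits for the workers to finish, then deletes all the tasks.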
 class WorkersPool {
  public:
   WorkersPool() {}
@@ -351,16 +356,31 @@
     }
   }
 
-  BlockingCounter& counter_to_decrement_when_ready() {
-    return counter_to_decrement_when_ready_;
+  void Execute(const std::vector<Task*>& tasks) {
+    assert(tasks.size() >= 1);
+    // One of the tasks will be run on the current thread.
+    int workers_count = tasks.size() - 1;
+    CreateWorkers(workers_count);
+    assert(workers_count <= workers_.size());
+    counter_to_decrement_when_ready_.Reset(workers_count);
+    int n = 0;
+    std::for_each(tasks.begin(), --tasks.end(), [this, &n](Task* task) {
+      workers_[n++]->StartWork(task);
+    });
+    // Execute the remaining workload immediately on the current thread.
+    Task* task = tasks.back();
+    task->local_allocator = &main_thread_task_allocator_;
+    task->Run();
+    // Wait for the workers submitted above to finish.
+    counter_to_decrement_when_ready_.Wait();
+    // Cleanup tasks (best to do this from the same thread that allocated
+    // the memory).
+    std::for_each(tasks.begin(), tasks.end(), [](Task* task) {
+      delete task;
+    });
   }
 
-  // Give work to a specific worker.
-  void StartWorker(int index, Task* task_) {
-    assert(static_cast<std::size_t>(index) < workers_.size());
-    workers_[index]->StartWork(task_);
-  }
-
+ private:
   // Ensures that the pool has at least the given count of workers.
   // If any new worker has to be created, this function waits for it to
   // be ready.
@@ -375,7 +395,6 @@
     counter_to_decrement_when_ready_.Wait();
   }
 
- private:
   // copy construction disallowed
   WorkersPool(const WorkersPool&) = delete;
 
@@ -385,6 +404,14 @@
 
   // The BlockingCounter used to wait for the workers.
   BlockingCounter counter_to_decrement_when_ready_;
+
+  // For N-threaded operations, we will use only N-1 worker threads
+  // while the last task will be run directly on the main thread.
+  // It will then use this main_thread_task_allocator_; having a
+  // dedicated allocator for that (separate from the base allocator_)
+  // allows using the same code for all tasks regardless of which
+  // thread they run on.
+  Allocator main_thread_task_allocator_;
 };
 
 // The task we use to implement a multi-threaded Gemm: a block of the
@@ -394,34 +421,41 @@
 template <typename KernelFormat, typename InputScalar, typename OutputScalar,
           typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
           MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
-          typename OutputPipelineType>
+  typename OutputPipelineType, typename GemmContextType>
 struct GemmWithPackedRhsTask : Task {
   typedef PackedSideBlock<typename KernelFormat::Lhs> PackedLhs;
   typedef PackedSideBlock<typename KernelFormat::Rhs> PackedRhs;
-  GemmWithPackedRhsTask(const KernelBase& _kernel,
+  GemmWithPackedRhsTask(GemmContextType* _context,
+                        const KernelBase& _kernel,
                         const MatrixMap<const InputScalar, LhsOrder>& _lhs,
                         const PackedRhs& _packed_rhs,
                         MatrixMap<OutputScalar, ResultOrder>* _result,
+                        const MatrixBlockBounds& _result_block,
                         const LhsOffset& _lhs_offset,
                         const RhsOffset& _rhs_offset,
                         const OutputPipelineType& _output_pipeline)
-      : kernel(_kernel),
+      : context(_context),
+        kernel(_kernel),
         lhs(_lhs),
         packed_rhs(_packed_rhs),
         result(*_result),
+        result_block(_result_block),
         lhs_offset(_lhs_offset),
         rhs_offset(_rhs_offset),
         output_pipeline(_output_pipeline) {}
 
-  void Run() const override {
+  void Run() override {
     ScopedProfilingLabel label("GemmWithPackedRhsTask");
 
-    const int rows = result.rows();
-    const int cols = result.cols();
+    const int rows = result_block.rows;
+    const int cols = result_block.cols;
     const int depth = lhs.cols();
 
     BlockParams block_params;
-    block_params.Init<KernelFormat>(rows, cols, depth, 1);
+    block_params.Init<KernelFormat>(rows, cols, depth, 1,
+                                    context->l1_bytes_to_use(),
+                                    context->l2_bytes_to_use(),
+                                    context->l2_rhs_factor());
 
     PackedLhs packed_lhs(Side::Lhs, local_allocator, block_params);
 
@@ -435,74 +469,92 @@
       for (int r = 0; r < rows; r += block_params.l2_rows) {
         int rs = std::min(block_params.l2_rows, rows - r);
 
-        PackLhs<BitDepthParams>(&packed_lhs, lhs.block(r, 0, rs, depth));
+        PackLhs(&packed_lhs, lhs.block(r, 0, rs, depth));
 
-        Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs);
+        Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs,
+                depth);
 
-        auto result_block = result.block(r, c, rs, cs);
-        UnpackResult<BitDepthParams>(&result_block, packed_result, depth,
-                                     packed_lhs.sums_of_each_slice(),
-                                     packed_rhs.sums_of_each_slice(),
-                                     lhs_offset, rhs_offset, output_pipeline);
+        auto curr_result_block = MatrixBlockBounds(
+            result_block.start_row + r, result_block.start_col + c, rs, cs);
+        UnpackResult<KernelFormat>(
+            &result, curr_result_block, packed_result, depth,
+            packed_lhs.sums_of_each_slice(), packed_rhs.sums_of_each_slice(),
+            lhs_offset.block(curr_result_block.start_row, rs),
+            rhs_offset.block(curr_result_block.start_col, cs), output_pipeline);
       }
     }
 
     local_allocator->Decommit();
   }
 
+  const GemmContextType* context;
   const KernelBase& kernel;
   const MatrixMap<const InputScalar, LhsOrder> lhs;
   const PackedRhs packed_rhs;
   MatrixMap<OutputScalar, ResultOrder> result;
+  const MatrixBlockBounds result_block;
   const LhsOffset& lhs_offset;
   const RhsOffset& rhs_offset;
   const OutputPipelineType& output_pipeline;
 };
 
-class MultiThreadGemmContext : public SingleThreadGemmContext {
+// This base class for multi-threading allows subclasses to implement their own
+// workers_pool() method.  See MultiThreadGemmContext below for an example;
+// any other implementation of workers_pool() must return an object with the
+// same public methods as WorkersPool.
+class MultiThreadGemmContextBase : public SingleThreadGemmContext {
  public:
-  MultiThreadGemmContext() : max_num_threads_(0) {}
-
   void set_max_num_threads(int n) { max_num_threads_ = n; }
 
   int max_num_threads() const { return max_num_threads_; }
 
+ protected:
+  // The maximum number of worker threads to use (including
+  // the master thread).
+  // The default value 1 means single-threading. That is the default
+  // because gemmlowp's primary target is mobile hardware, where thermal
+  // constraints usually mean that it may not be realistic to use more
+  // than 1 CPU core even if multiple cores are present.
+  // The special value 0 means try to detect the number of hardware threads.
+  // Note: this assumes that all CPU cores are equivalent. That assumption
+  // is defeated on big.LITTLE ARM devices, where we have no API to query
+  // the number of big cores (which is typically what we would want to use,
+  // leaving aside above-mentioned thermal issues). That is the other reason
+  // why the best compromise here is to let max_num_threads_ default to 1,
+  // so users who want multi-threading have to make the decision of how many
+  // threads to use by themselves.
+  int max_num_threads_ = 1;
+};
+
+class MultiThreadGemmContext : public MultiThreadGemmContextBase {
+ public:
   WorkersPool* workers_pool() { return &workers_pool_; }
 
-  Allocator* main_thread_task_allocator() {
-    return &main_thread_task_allocator_;
-  }
-
- protected:
+ private:
   // The workers pool used by MultiThreadGemm. Making
   // this part of the context allows it to be persistent,
   // avoiding recreating threads on every Gemm.
   WorkersPool workers_pool_;
-
-  // The maximum number of worker threads to use (in addition
-  // to the master thread).
-  // The default value 0 means the default behavior of
-  // detecting the number of hardware threads. Nonzero values mean
-  // skipping and overriding hardware detection.
-  int max_num_threads_;
-
-  // For N-threaded operations, we will use only N-1 worker threads
-  // while the last task will be run directly on the main thread.
-  // It will then use this main_thread_task_allocator_; having a
-  // dedicated allocator for that (separate from the base allocator_)
-  // allows to use the same code for all tasks regardless of which
-  // thread they run on.
-  Allocator main_thread_task_allocator_;
 };
 
+// Needed by Chrome native builds.
+#ifndef _SC_NPROCESSORS_CONF
+#define _SC_NPROCESSORS_CONF _SC_NPROCESSORS_ONLN
+#endif
+
 // Determines how many threads should be used for a given Gemm
 // operation.
 template <int KernelRows>
-inline int HowManyThreads(MultiThreadGemmContext* context, int rows, int cols,
-                          int depth) {
-  // First check if the user set an explicit maximum number of threads.
-  int max_count = context->max_num_threads();
-  if (!max_count) {
+inline int HowManyThreads(int max_num_threads, int rows, int cols, int depth) {
+  // Early-exit in the default case where multi-threading is disabled.
+  if (max_num_threads == 1) {
+    return 1;
+  }
+
+  // Determine the maximum number of threads.
+  int max_count = max_num_threads;
+  // The special value 0 means try to determine the total number of cores.
+  if (max_count == 0) {
     // No user-set maximum number of threads, so we need to
     // do some hardware detection.
     // This is expensive to query so we do it only once.
@@ -553,15 +605,15 @@
 }
 
 // The main multi-threaded Gemm function.
-// To understand it, first read the code of SingleThreadedGemm().
+// To understand it, first read the code of SingleThreadGemm().
 // The parallelization scheme used here is to have this master function
 // pack a block of RHS and then start worker threads to pack a block of LHS
 // each, and accumulate the corresponding products.
 template <typename KernelFormat, typename InputScalar, typename OutputScalar,
           typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
           MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
-          typename OutputPipelineType>
-void MultiThreadGemm(MultiThreadGemmContext* context, const KernelBase& kernel,
+          typename OutputPipelineType, typename GemmContextType>
+void MultiThreadGemm(GemmContextType* context, const KernelBase& kernel,
                      const MatrixMap<const InputScalar, LhsOrder>& lhs,
                      const MatrixMap<const InputScalar, RhsOrder>& rhs,
                      MatrixMap<OutputScalar, ResultOrder>* result,
@@ -575,12 +627,16 @@
   int cols = result->cols();
   int depth = lhs.cols();
 
+  // Zero sizes should have been caught earlier and early-returned.
   assert(rows > 0);
   assert(cols > 0);
   assert(depth > 0);
 
-  const int thread_count =
-      HowManyThreads<KernelFormat::kRows>(context, rows, cols, depth);
+  // The case of rows < cols should have been caught earlier and transposed.
+  assert(rows >= cols);
+
+  const int thread_count = HowManyThreads<KernelFormat::kRows>(
+      context->max_num_threads(), rows, cols, depth);
   if (thread_count == 1) {
     return SingleThreadGemm<KernelFormat, InputScalar, OutputScalar,
                             BitDepthParams>(context, kernel, lhs, rhs, result,
@@ -589,26 +645,22 @@
   }
   assert(thread_count > 1);
 
-  // We choose to use a worker thread for all but one
-  // of the thread workloads. The remaining thread workload will be
-  // executed immediately on the current thread.
-  // In this way, the total number of threads (1 master, N-1 workers)
-  // equals the value returned by HowManyThread. This simple
-  // 1:1 mapping of threads to physical cores, is very important
-  // to getting good multithreaded performance especially for
-  // not-very-large GEMMs, and especially on Android.
-  const int workers_count = thread_count - 1;
+  // Use a simple 1:1 mapping of tasks to physical cores, which is very
+  // important for getting good multithreaded performance, especially for
+  // not-very-large GEMMs, and especially on Android.
+  const int task_count = thread_count;
 
   Allocator* allocator = context->allocator();
-  WorkersPool* workers_pool = context->workers_pool();
-
-  workers_pool->CreateWorkers(workers_count);
+  auto* workers_pool = context->workers_pool();
 
   BlockParams block_params;
-  block_params.Init<KernelFormat>(rows, cols, depth, workers_count);
+  block_params.Init<KernelFormat>(rows, cols, depth, task_count,
+                                  context->l1_bytes_to_use(),
+                                  context->l2_bytes_to_use(),
+                                  context->l2_rhs_factor());
 
-  PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(
-      Side::Rhs, allocator, block_params);
+  PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(Side::Rhs, allocator,
+                                                         block_params);
   allocator->Commit();
 
   // We loop over large blocks of the RHS.
@@ -616,37 +668,29 @@
     int cs = std::min(block_params.l2_cols, cols - c);
 
     // Pack a large block of the RHS.
-    PackRhs<BitDepthParams>(&packed_rhs, rhs.block(0, c, depth, cs));
+    PackRhs(&packed_rhs, rhs.block(0, c, depth, cs));
 
     // Give work to each worker.
+    std::vector<Task*> tasks;
     int next_start_row = 0;
-    workers_pool->counter_to_decrement_when_ready().Reset(workers_count);
-    for (int thread = 0; thread < thread_count; thread++) {
+    for (int n = 0; n < task_count; ++n) {
       int start_row = next_start_row;
       next_start_row = std::min(rows, RoundUp<KernelFormat::kRows>(
-                                          rows * (thread + 1) / thread_count));
+                                          rows * (n + 1) / task_count));
 
       int block_rows = next_start_row - start_row;
       auto lhs_block = lhs.block(start_row, 0, block_rows, depth);
-      auto result_block = result->block(start_row, c, block_rows, cs);
-      typedef GemmWithPackedRhsTask<KernelFormat, InputScalar, OutputScalar,
-                                    BitDepthParams, LhsOrder, RhsOrder,
-                                    ResultOrder, LhsOffset, RhsOffset,
-                                    OutputPipelineType>
+      typedef GemmWithPackedRhsTask<
+          KernelFormat, InputScalar, OutputScalar, BitDepthParams, LhsOrder,
+          RhsOrder, ResultOrder, LhsOffset, RhsOffset, OutputPipelineType,
+          GemmContextType>
           TaskType;
-      auto task = new TaskType(kernel, lhs_block, packed_rhs, &result_block,
-                               lhs_offset, rhs_offset, output_pipeline);
-      if (thread < workers_count) {
-        workers_pool->StartWorker(thread, task);
-      } else {
-        // Execute the remaining workload immediately on the current thread.
-        task->local_allocator = context->main_thread_task_allocator();
-        task->Run();
-        delete task;
-      }
+      tasks.push_back(new TaskType(
+          context, kernel, lhs_block, packed_rhs, result,
+          MatrixBlockBounds(start_row, c, block_rows, cs), lhs_offset,
+          rhs_offset, output_pipeline));
     }
-    // Wait for the workers.
-    workers_pool->counter_to_decrement_when_ready().Wait();
+    // Execute the work on the workers (and partially on this thread).
+    workers_pool->Execute(tasks);
   }
 
   allocator->Decommit();
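To make the parallelization scheme above easier to trace, here is a small self-contained program reproducing just the row-splitting step: each of the task_count tasks gets a contiguous row range whose end is rounded up to a multiple of the kernel's row count, mirroring RoundUp<KernelFormat::kRows>() in the loop above. The kernel height of 4 below is an arbitrary illustrative value, not a value taken from this diff.

// Self-contained toy of the row split used for the 1:1 task-to-core mapping.
#include <algorithm>
#include <cstdio>

int main() {
  const int kKernelRows = 4;   // assumed kernel height, for illustration only
  const int rows = 37;
  const int task_count = 3;
  int next_start_row = 0;
  for (int n = 0; n < task_count; ++n) {
    const int start_row = next_start_row;
    int end_row = rows * (n + 1) / task_count;
    // Round up to a multiple of kKernelRows, then clamp to the matrix height.
    end_row = ((end_row + kKernelRows - 1) / kKernelRows) * kKernelRows;
    end_row = std::min(rows, end_row);
    next_start_row = end_row;
    std::printf("task %d: rows [%d, %d)\n", n, start_row, end_row);
  }
  return 0;
}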
diff --git a/internal/output.h b/internal/output.h
index 28c881a..8ccb8ee 100644
--- a/internal/output.h
+++ b/internal/output.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -23,216 +23,204 @@
 #include <tuple>
 #include <type_traits>
 
+#include "../fixedpoint/fixedpoint.h"
 #include "../public/output_stages.h"
-#include "fixedpoint.h"
+#include "simd_wrappers.h"
 
 namespace gemmlowp {
 
-// A Fragment is a small fixed-size matrix typically stored in one or
-// a few architecture-specific SIMD vectors. Besides plain old scalar types
-// such as int32_t, Fragment types are what can be used as input/output data
-// types for output pipeline stages.
-//
-// More details:
-//
-// In the generic scalar code in this file, we have only implemented
-// evaluation of output stages for scalar inputs (e.g. plain int32_t values).
-// Other files (e.g. output_neon.h) are to provide SIMD paths by implementing
-// evaluation of output stages for SIMD vector types. However, this raises
-// the question of how the different values ("lanes") in a SIMD vector
-// correspond to different values in the whole matrices. For simple entry-wise
-// output stages, this doesn't matter, but for other output stages depending
-// on position within the whole matrix, this does matter. To solve this
-// problem, rather than implementing evaluation of output stages for raw
-// SIMD vector types, we wrap SIMD vector types in "fragment" structs that
-// bring the additional structure of "shape" i.e. mapping SIMD lanes to
-// matrix entries, and we specialize evaluation of output stage for such
-// fragment types. The Fragment template struct here is how we generate
-// all fragment structs. For example, in output_neon.h, it may be specialized
-// with DataType=int32x4_t, Rows=4, Cols=1. MapOrder doesn't matter for
-// vector shapes. While Fragment is only used for SIMD paths, we leave it
-// here in this platform-generic file because this same template should
-// cover the needs of any SIMD architectures.
-template <typename tDataType, int tRows, int tCols, MapOrder tOrder>
-struct Fragment {
-  typedef tDataType DataType;
-  static const int kRows = tRows;
-  static const int kCols = tCols;
-  static const MapOrder kOrder = tOrder;
-
-  Fragment() {}
-  Fragment(const DataType& d) : data(d) {}
-  operator DataType() const { return data; }
-
-  DataType data;
-};
-
-typedef Fragment<std::int32_t, 1, 1, MapOrder::ColMajor> FragmentInt32x1x1;
-typedef Fragment<std::uint8_t, 1, 1, MapOrder::ColMajor> FragmentUint8x1x1;
-
-// OutputStageEvalImpl is the template that we specialize to provide
-// implementations of each output stage for each type of input data.
-//
-// Each specialization provides a OutputType typedef and an Eval function
-// returning OutputType. The OutputType typically depends on the InputType.
-//
-// There are two dimensions in which input data types can vary:
-//   1. Different output stages may expect different data types. The
-//      only hard constraint is that the first stage accepts int32, as
-//      the unpack stage produces int32 accumulators.
-//   2. For a given scalar data type such as int32, there is still the
-//      possibility of having SIMD vector types such as NEON int32x4_t,
-//      typically wrapped as "fragment" types, see struct Fragment.
-//      Thus, there can be several OutputStageEvalImpl
-//      specializations for a single OutputStageType, for different
-//      InputType's.
-template <typename OutputStageType, typename InputType>
-struct OutputStageEvalImpl {
+template <typename OutputStage, typename InputBufferType>
+struct OutputStageEvalBufferImpl {
   // This generic template body should never be hit.
   static_assert(
-      std::is_same<InputType, void>::value,
+      std::is_same<InputBufferType, void>::value,
       "Unimplemented: missing implementation of this output pipeline stage "
       "for this data type. This would happen if some architecture-specific "
       "SIMD back-end (output_$arch.h) were incomplete.");
-
-  OutputStageEvalImpl(const OutputStageType&) {}
 };
 
-// Implementation of OutputStageQuantizeDownInt32ToUint8Scale for scalar data
-template <>
-struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8Scale,
-                           FragmentInt32x1x1> {
-  typedef FragmentInt32x1x1 InputType;
-  typedef FragmentInt32x1x1 OutputType;
-  typedef OutputStageQuantizeDownInt32ToUint8Scale OutputStage;
+template <typename OutputStage, typename InputType>
+struct OutputStageEvalImpl {
+  static constexpr int kRows = InputType::kRows;
+  static constexpr int kCols = InputType::kCols;
+  using InputBufferType = typename InputType::BufferType;
+  using BufferEvalImplType =
+      OutputStageEvalBufferImpl<OutputStage, InputBufferType>;
+  using OutputBufferType = typename BufferEvalImplType::OutputType;
+  using OutputScalarType = typename OutputBufferType::ScalarType;
+  using OutputType = RegisterBlock<OutputScalarType, kRows, kCols>;
 
-  OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
+  OutputStageEvalImpl(const OutputStage& s) : buffer_eval_impl(s) {}
 
   OutputType Eval(InputType input, int, int) const {
-    const std::int32_t result_shift = output_stage.result_shift;
+    OutputType output;
+    output.buf = buffer_eval_impl.Eval(input.buf);
+    return output;
+  }
+
+  const BufferEvalImplType buffer_eval_impl;
+};
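The struct above is the pivot of this refactor: per-block evaluation now forwards the block's flat register buffer to a shape-agnostic OutputStageEvalBufferImpl and rewraps the result, so SIMD back-ends only have to implement the buffer-level stages. The RegisterBlock/RegisterBuffer types live in simd_wrappers.h and are not shown here, so the sketch below mimics the delegation pattern with toy stand-in types rather than the real ones.

// Minimal analogue of the block -> buffer delegation, over plain scalars.
#include <array>
#include <cstdio>

template <int Size>
struct ToyBuffer {
  std::array<int, Size> reg;   // stand-in for RegisterBuffer's registers
};

template <int Rows, int Cols>
struct ToyBlock {
  ToyBuffer<Rows * Cols> buf;  // stand-in for RegisterBlock's buffer
};

// "Buffer impl": shape-agnostic, element-wise work.
template <int Size>
ToyBuffer<Size> AddOne(const ToyBuffer<Size>& in) {
  ToyBuffer<Size> out;
  for (int i = 0; i < Size; ++i) out.reg[i] = in.reg[i] + 1;
  return out;
}

// "Block impl": keeps the shape, delegates the arithmetic to the buffer impl.
template <int Rows, int Cols>
ToyBlock<Rows, Cols> AddOne(const ToyBlock<Rows, Cols>& in) {
  ToyBlock<Rows, Cols> out;
  out.buf = AddOne(in.buf);
  return out;
}

int main() {
  ToyBlock<2, 2> block;
  for (int i = 0; i < 4; ++i) block.buf.reg[i] = i + 1;
  const ToyBlock<2, 2> result = AddOne(block);
  std::printf("%d %d %d %d\n", result.buf.reg[0], result.buf.reg[1],
              result.buf.reg[2], result.buf.reg[3]);
  return 0;
}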
+
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageQuantizeDownInt32ToUint8Scale,
+                                 RegisterBuffer<std::int32_t, Size>> {
+  using InputType = RegisterBuffer<std::int32_t, Size>;
+  using OutputType = RegisterBuffer<std::int32_t, Size>;
+
+  typedef OutputStageQuantizeDownInt32ToUint8Scale OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
+
+  OutputType Eval(InputType input) const {
+    const int result_shift = output_stage.result_shift;
     const std::int32_t result_mult_int = output_stage.result_mult_int;
-    const std::int32_t result_offset = output_stage.result_offset;
-    const std::int32_t kRoundingTerm =
-        (result_shift < 1) ? 0 : (1 << (result_shift - 1));
-    return ((input + result_offset) * result_mult_int + kRoundingTerm) >>
-           result_shift;
+    using RegisterType = typename InputType::RegisterType;
+    const RegisterType result_offset =
+        Dup<RegisterType>(output_stage.result_offset);
+    OutputType output;
+    for (int i = 0; i < InputType::kRegisterCount; i++) {
+      output.reg[i] = RoundingDivideByPOT(
+          Mul(Add(input.reg[i], result_offset), result_mult_int), result_shift);
+    }
+    return output;
   }
 
   const OutputStage& output_stage;
 };
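The register loop above is the vectorized form of the scalar formula that the deleted code computed directly; a self-contained scalar reference with a small worked number:

// Scalar reference of the quantize-down stage, matching the deleted scalar
// path above: ((input + offset) * mult + rounding) >> shift.
#include <cstdint>
#include <cstdio>

std::int32_t QuantizeDownScalar(std::int32_t input, std::int32_t result_offset,
                                std::int32_t result_mult_int,
                                int result_shift) {
  const std::int32_t rounding =
      (result_shift < 1) ? 0 : (1 << (result_shift - 1));
  return ((input + result_offset) * result_mult_int + rounding) >> result_shift;
}

int main() {
  // Example: offset -100, multiplier 3, shift 4 -> (155 * 3 + 8) >> 4 = 29.
  std::printf("%d\n", QuantizeDownScalar(255, -100, 3, 4));
  return 0;
}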
 
-template <>
-struct OutputStageEvalImpl<
-    OutputStageQuantizeDownInt32ToUint8ScalePC<VectorShape::Col>,
-    FragmentInt32x1x1> {
-  typedef FragmentInt32x1x1 InputType;
-  typedef FragmentInt32x1x1 OutputType;
-  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<VectorShape::Col>
-      OutputStage;
+template <int Rows, int Cols, VectorShape Shape>
+struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>,
+                           RegisterBlock<std::int32_t, Rows, Cols>> {
+  typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
+  typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
+  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> OutputStage;
 
   OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
 
   OutputType Eval(InputType input, int row, int col) const {
-    const std::int32_t result_shift = output_stage.result_shift;
-    const std::int32_t result_mult_int = output_stage.result_mult_int(row);
-    const std::int32_t result_offset = output_stage.result_offset(row);
-    const std::int32_t kRoundingTerm =
-        (result_shift < 1) ? 0 : (1 << (result_shift - 1));
-    return ((input + result_offset) * result_mult_int + kRoundingTerm) >>
-           result_shift;
+    OutputType output;
+    const int result_shift = output_stage.result_shift;
+    const int pos = Shape == VectorShape::Col ? row : col;
+    const auto result_mult_int =
+        LoadForBroadcasting<InputType>(output_stage.result_mult_int, pos);
+    const auto result_offset =
+        LoadForBroadcasting<InputType>(output_stage.result_offset, pos);
+    const auto dividend = BroadcastMul<InputType>(
+        BroadcastAdd<InputType>(input, result_offset), result_mult_int);
+    for (int i = 0; i < InputType::kRegisterCount; i++) {
+      output.buf.reg[i] =
+          RoundingDivideByPOT(dividend.buf.reg[i], result_shift);
+    }
+    return output;
   }
 
   const OutputStage& output_stage;
 };
 
-template <>
-struct OutputStageEvalImpl<
-    OutputStageQuantizeDownInt32ToUint8ScalePC<VectorShape::Row>,
-    FragmentInt32x1x1> {
-  typedef FragmentInt32x1x1 InputType;
-  typedef FragmentInt32x1x1 OutputType;
-  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<VectorShape::Row>
-      OutputStage;
+template <int Size>
+struct OutputStageEvalBufferImpl<
+    OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
+    RegisterBuffer<std::int32_t, Size>> {
+  typedef RegisterBuffer<std::int32_t, Size> InputType;
+  typedef RegisterBuffer<std::int32_t, Size> OutputType;
 
-  OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
+  typedef OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint OutputStage;
 
-  OutputType Eval(InputType input, int row, int col) const {
-    const std::int32_t result_shift = output_stage.result_shift;
-    const std::int32_t result_mult_int = output_stage.result_mult_int(col);
-    const std::int32_t result_offset = output_stage.result_offset(col);
-    const std::int32_t kRoundingTerm =
-        (result_shift < 1) ? 0 : (1 << (result_shift - 1));
-    return ((input + result_offset) * result_mult_int + kRoundingTerm) >>
-           result_shift;
+  OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    using RegisterType = typename InputType::RegisterType;
+    const RegisterType result_offset_after_shift =
+        Dup<RegisterType>(output_stage.result_offset_after_shift);
+    for (int i = 0; i < InputType::kRegisterCount; i++) {
+      const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul(
+          input.reg[i], output_stage.result_fixedpoint_multiplier);
+      output.reg[i] =
+          Add(RoundingDivideByPOT(mulhigh_val, output_stage.result_shift),
+              result_offset_after_shift);
+    }
+    return output;
   }
 
   const OutputStage& output_stage;
 };
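For the fixed-point variant above, a scalar sketch in int64 arithmetic can help: SaturatingRoundingDoublingHighMul takes the high 32 bits of twice the 32x32 product, and RoundingDivideByPOT is a round-to-nearest shift. The exact nudge and saturation handling live in fixedpoint/fixedpoint.h, so the helpers below are assumptions about those semantics (non-negative inputs, no saturation), not a verbatim copy.

// Sketch of the fixed-point requantization path, with assumed rounding details.
#include <cstdint>
#include <cstdio>

// High 32 bits of a doubled 32x32 product, rounded to nearest (sketch).
std::int32_t DoublingHighMulSketch(std::int32_t a, std::int32_t b) {
  const std::int64_t ab = static_cast<std::int64_t>(a) * b;
  const std::int64_t nudge = (ab >= 0) ? (1ll << 30) : (1 - (1ll << 30));
  return static_cast<std::int32_t>((ab + nudge) / (1ll << 31));
}

// Rounding divide by a power of two (sketch; assumes exponent >= 1, x >= 0).
std::int32_t RoundingDivideByPOTSketch(std::int32_t x, int exponent) {
  const std::int32_t rounding = 1 << (exponent - 1);
  return (x + rounding) >> exponent;
}

int main() {
  const std::int32_t acc = 1000;             // int32 accumulator
  const std::int32_t multiplier = 1 << 30;   // fixed-point multiplier ~0.5
  const int shift = 3;
  const std::int32_t offset_after_shift = 128;
  const std::int32_t scaled = DoublingHighMulSketch(acc, multiplier);  // 500
  const std::int32_t out =
      RoundingDivideByPOTSketch(scaled, shift) + offset_after_shift;
  std::printf("%d\n", out);  // 191 with these example values
  return 0;
}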
 
 // Implementation of OutputStageSaturatingCastToUint8 for scalar data
-template <>
-struct OutputStageEvalImpl<OutputStageSaturatingCastToUint8,
-                           FragmentInt32x1x1> {
-  typedef FragmentInt32x1x1 InputType;
-  typedef FragmentUint8x1x1 OutputType;
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+                                 RegisterBuffer<std::int32_t, Size>> {
+  typedef RegisterBuffer<std::int32_t, Size> InputType;
+  typedef RegisterBuffer<std::uint8_t, Size> OutputType;
+  static_assert(InputType::kRegisterLanes == 1,
+                "This path is only for scalar values");
+
   typedef OutputStageSaturatingCastToUint8 OutputStage;
 
-  OutputStageEvalImpl(const OutputStage&) {}
+  OutputStageEvalBufferImpl(const OutputStage&) {}
 
-  OutputType Eval(InputType input, int, int) const {
-    std::int32_t data = input.data;
-    return data > 255 ? 255 : data < 0 ? 0 : data;
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    for (int i = 0; i < InputType::kRegisterCount; i++) {
+      std::int32_t data = input.reg[i];
+      output.reg[i] = data > 255 ? 255 : data < 0 ? 0 : data;
+    }
+    return output;
   }
 };
 
-// Implementation of OutputStageBiasAddition for scalar data
-template <typename VectorType>
+template <int Rows, int Cols, typename VectorType>
 struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>,
-                           FragmentInt32x1x1> {
-  typedef FragmentInt32x1x1 InputType;
-  typedef FragmentInt32x1x1 OutputType;
+                           RegisterBlock<std::int32_t, Rows, Cols>> {
+  typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
+  typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
   typedef OutputStageBiasAddition<VectorType> OutputStage;
 
   OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
 
   OutputType Eval(InputType input, int row, int col) const {
-    if (VectorType::kShape == VectorShape::Row) {
-      return input + output_stage.bias_vector(col);
-    } else {
-      return input + output_stage.bias_vector(row);
-    }
+    const int pos = VectorType::kShape == VectorShape::Row ? col : row;
+    return BroadcastAdd<InputType>(
+        input, LoadForBroadcasting<InputType>(output_stage.bias_vector, pos));
   }
 
   const OutputStage& output_stage;
 };
 
-// Implementation of OutputStageClamp for scalar data
-template <>
-struct OutputStageEvalImpl<OutputStageClamp, FragmentInt32x1x1> {
-  typedef FragmentInt32x1x1 InputType;
-  typedef FragmentInt32x1x1 OutputType;
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageClamp,
+                                 RegisterBuffer<std::int32_t, Size>> {
+  typedef RegisterBuffer<std::int32_t, Size> InputType;
+  typedef RegisterBuffer<std::int32_t, Size> OutputType;
+
   typedef OutputStageClamp OutputStage;
 
-  OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
+  OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
 
-  OutputType Eval(InputType input, int, int) const {
-    const std::int32_t min = output_stage.min;
-    const std::int32_t max = output_stage.max;
-    return std::min(std::max(input.data, min), max);
+  OutputType Eval(InputType input) const {
+    using RegisterType = typename InputType::RegisterType;
+    const RegisterType min = Dup<RegisterType>(output_stage.min);
+    const RegisterType max = Dup<RegisterType>(output_stage.max);
+    OutputType output;
+    for (int i = 0; i < InputType::kRegisterCount; i++) {
+      output.reg[i] = Min(Max(input.reg[i], min), max);
+    }
+    return output;
   }
 
   const OutputStage& output_stage;
 };
 
-// Implementation of OutputStageTanh for either scalar or SIMD data
-template <typename tInputType>
-struct OutputStageTanhEvalImpl {
-  typedef tInputType InputType;
-  typedef InputType OutputType;
-  typedef typename InputType::DataType DataType;
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageTanh,
+                                 RegisterBuffer<std::int32_t, Size>> {
+  typedef RegisterBuffer<std::int32_t, Size> InputType;
+  typedef RegisterBuffer<std::int32_t, Size> OutputType;
+  using RegisterType = typename InputType::RegisterType;
+  typedef RegisterType DataType;
   typedef OutputStageTanh OutputStage;
 
-  OutputStageTanhEvalImpl(const OutputStage& s) : output_stage(s) {
+  OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {
     const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32;
     const std::int32_t real_amplitude_as_int32 =
         output_stage.real_amplitude_as_int32;
@@ -248,8 +236,8 @@
       inverse_amplitude_normalized_double *= 2;
       inverse_amplitude_neg_exponent++;
     }
-    inverse_amplitude_normalized =
-        ToFixedPoint<DataType, 0>(inverse_amplitude_normalized_double);
+    inverse_amplitude_normalized = FixedPoint<DataType, 0>::FromDouble(
+        inverse_amplitude_normalized_double);
 
     double amplitude_normalized_double = real_amplitude_as_int32;
     amplitude_exponent = 0;
@@ -258,39 +246,44 @@
       amplitude_exponent++;
     }
     amplitude_normalized =
-        ToFixedPoint<DataType, 0>(amplitude_normalized_double);
+        FixedPoint<DataType, 0>::FromDouble(amplitude_normalized_double);
   }
 
-  OutputType Eval(InputType input, int, int) const {
+  OutputType Eval(InputType input) const {
     const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32;
 
     typedef FixedPoint<DataType, 3> F3;
     typedef FixedPoint<DataType, 0> F0;
 
-    // fixed-point affine transformation
-    DataType input_centered =
-        Sub(input.data, Dup<DataType>(real_zero_as_int32));
-    F3 fixedpoint_input =
-        F3::FromRaw(input_centered) * inverse_amplitude_normalized;
-    // left shift
-    fixedpoint_input.raw() =
-        ShiftLeft(fixedpoint_input.raw(), 28 - inverse_amplitude_neg_exponent);
-    // fixed-point tanh and multiplication
-    F0 fixedpoint_output = tanh(fixedpoint_input) * amplitude_normalized;
-    // right shift
-    DataType int32_output =
-        Add(Dup<DataType>(real_zero_as_int32),
-            ShiftRight(fixedpoint_output.raw(), 31 - amplitude_exponent));
+    OutputType output;
 
-    DataType mask_if_below_cutoff_min =
-        MaskIfLessThanOrEqual(input.data, Dup<DataType>(input_cutoff_min));
-    DataType mask_if_above_cutoff_max =
-        MaskIfGreaterThanOrEqual(input.data, Dup<DataType>(input_cutoff_max));
+    for (int i = 0; i < OutputType::kRegisterCount; i++) {
+      // fixed-point affine transformation
+      DataType input_centered =
+          Sub(input.reg[i], Dup<DataType>(real_zero_as_int32));
+      F3 fixedpoint_input =
+          F3::FromRaw(input_centered) * inverse_amplitude_normalized;
+      // left shift
+      fixedpoint_input.raw() = ShiftLeft(fixedpoint_input.raw(),
+                                         28 - inverse_amplitude_neg_exponent);
+      // fixed-point tanh and multiplication
+      F0 fixedpoint_output = tanh(fixedpoint_input) * amplitude_normalized;
+      // right shift
+      DataType int32_output =
+          Add(Dup<DataType>(real_zero_as_int32),
+              ShiftRight(fixedpoint_output.raw(), 31 - amplitude_exponent));
 
-    return SelectUsingMask(
-        mask_if_below_cutoff_min, Dup<DataType>(output_min),
-        SelectUsingMask(mask_if_above_cutoff_max, Dup<DataType>(output_max),
-                        int32_output));
+      DataType mask_if_below_cutoff_min =
+          MaskIfLessThanOrEqual(input.reg[i], Dup<DataType>(input_cutoff_min));
+      DataType mask_if_above_cutoff_max = MaskIfGreaterThanOrEqual(
+          input.reg[i], Dup<DataType>(input_cutoff_max));
+
+      output.reg[i] = SelectUsingMask(
+          mask_if_below_cutoff_min, Dup<DataType>(output_min),
+          SelectUsingMask(mask_if_above_cutoff_max, Dup<DataType>(output_max),
+                          int32_output));
+    }
+    return output;
   }
 
   const OutputStage& output_stage;
@@ -302,13 +295,6 @@
   int amplitude_exponent;
 };
 
-template <>
-struct OutputStageEvalImpl<OutputStageTanh, FragmentInt32x1x1>
-    : OutputStageTanhEvalImpl<FragmentInt32x1x1> {
-  OutputStageEvalImpl(const OutputStageTanh& output_stage)
-      : OutputStageTanhEvalImpl(output_stage) {}
-};
-
 // OutputPipelineOutputType is a helper to determine the output data type of a
 // pipeline, for a
 // given input data type. It is a recursive template; see the explanation on
@@ -377,13 +363,32 @@
   }
 };
 
+template <typename RegisterBlockType, typename DstType>
+struct StoreFinalOutputImpl {
+  static_assert(std::is_same<RegisterBlockType, void>::value,
+                "This generic impl should never be hit");
+};
+
+template <typename ScalarType, int Rows, int Cols, typename DstType>
+struct StoreFinalOutputImpl<RegisterBlock<ScalarType, Rows, Cols>, DstType> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  static void Run(const RegisterBlockType& src, DstType* dst, int row,
+                  int col) {
+    for (int r = 0; r < Rows; r++) {
+      for (int c = 0; c < Cols; c++) {
+        *dst->data(row + r, col + c) = src.buf.reg[r + c * Rows];
+      }
+    }
+  }
+};
+
 // StoreFinalOutput takes the final value at the end of the output pipeline and
 // stores it into the destination matrix. It can be specialized for different
 // data types; the generic implementation here is typically used only for plain
 // old scalar (not SIMD) types.
-template <typename OutputType, typename DstType>
-void StoreFinalOutput(OutputType value, DstType* dst, int row, int col) {
-  *dst->data(row, col) = value;
+template <typename RegisterBlockType, typename DstType>
+void StoreFinalOutput(RegisterBlockType src, DstType* dst, int row, int col) {
+  StoreFinalOutputImpl<RegisterBlockType, DstType>::Run(src, dst, row, col);
 }
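The generic store above reads the block's buffer with index r + c * Rows, i.e. it assumes the register buffer is laid out column-major; a tiny self-contained check of that indexing, with a plain array standing in for the real buffer:

// Toy check of the column-major buf index r + c * Rows used above.
#include <cstdio>

int main() {
  const int Rows = 2, Cols = 3;
  int buf[Rows * Cols] = {0, 1, 2, 3, 4, 5};  // column-major storage
  for (int r = 0; r < Rows; ++r) {
    for (int c = 0; c < Cols; ++c) {
      std::printf("%d ", buf[r + c * Rows]);  // prints row r of the block
    }
    std::printf("\n");
  }
  return 0;
}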
 
 template <typename OutputPipelineType, typename InputType>
@@ -396,20 +401,23 @@
   // result
   // of the unpack stage and stores it into the destination matrix.
   template <typename DstType>
-  void Execute(InputType input, DstType* dst, int row, int col) {
+  void Execute(InputType input, DstType* dst, int src_global_row,
+               int src_global_col, int dst_row, int dst_col) const {
     // Statically assert that the output pipeline matches the given destination
     // matrix's scalar type.
-    typedef typename OutputPipelineOutputType<OutputPipelineType, 0,
-                                              FragmentInt32x1x1>::Type::DataType
+    typedef typename OutputPipelineOutputType<
+        OutputPipelineType, 0, InputType>::Type::BufferType::ScalarType
         ScalarOutputType;
     typedef typename DstType::Scalar ScalarDstType;
     static_assert(std::is_same<ScalarOutputType, ScalarDstType>::value,
                   "mismatched destination scalar type and output pipeline");
 
     // Evaluate the output pipeline.
-    auto output = output_pipeline_eval_impl_.Eval(input, row, col);
+    auto output =
+        output_pipeline_eval_impl_.Eval(input, src_global_row, src_global_col);
     // Store the result into the destination matrix.
-    StoreFinalOutput(output, dst, row, col);
+    StoreFinalOutput(output, dst, dst_row, dst_col);
   }
 
   const OutputPipelineEvalImpl<OutputPipelineType, 0, InputType>
@@ -418,4 +426,10 @@
 
 }  // namespace gemmlowp
 
+#ifdef GEMMLOWP_NEON
+#include "output_neon.h"
+#elif defined(GEMMLOWP_SSE4)
+#include "output_sse.h"
+#endif
+
 #endif  // GEMMLOWP_INTERNAL_OUTPUT_H_
diff --git a/internal/output_neon.h b/internal/output_neon.h
index ed5f57c..7e111e5 100644
--- a/internal/output_neon.h
+++ b/internal/output_neon.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -23,257 +23,410 @@
 
 namespace gemmlowp {
 
-// Definitions of Fragment types wrapping NEON vector types.
-typedef Fragment<int32x4_t, 4, 1, MapOrder::ColMajor> NEONFragmentInt32x4x1;
-typedef Fragment<int32x4x4_t, 16, 1, MapOrder::ColMajor> NEONFragmentInt32x16x1;
-typedef Fragment<uint8x8_t, 4, 1, MapOrder::ColMajor> NEONFragmentUint8x4x1;
-typedef Fragment<uint8x16_t, 16, 1, MapOrder::ColMajor> NEONFragmentUint8x16x1;
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+                                 RegBufferInt32<4>> {
+  typedef RegBufferInt32<4> InputType;
+  typedef RegBufferUint8<4> OutputType;
 
-// The code in unpack_neon.h will whenever possible process
-// 16 entries at once (4 SIMD vectors of 4 entries each at once),
-// to offer the compiler better optimization opportunities, reducing
-// register dependencies. From the perspective of interfacing with the output
-// pipeline, this takes the form of passing Fragment types wrapping int32x4x4_t
-// data. In most cases, such data is handled simply by handling separately its
-// 4 int32x4_t components. This partial specialization handles that for
-// arbitrary output stages implementing a int32x4_t path. Only some output
-// stages below will override this to use custom code to handle int32x4x4_t
-// data all at once (see OutputStageSaturatingCastToUint8 below).
-template <typename OutputStageType>
-struct OutputStageEvalImpl<OutputStageType, NEONFragmentInt32x16x1> {
-  typedef NEONFragmentInt32x16x1 InputType;
-  typedef NEONFragmentInt32x16x1 OutputType;
-  typedef OutputStageEvalImpl<OutputStageType, NEONFragmentInt32x4x1>
-      ImplInt32x4;
-  OutputStageEvalImpl(const OutputStageType& s) : impl_int32x4(s) {}
+  typedef OutputStageSaturatingCastToUint8 OutputStage;
 
-  OutputType Eval(InputType input, int row, int col) const {
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
     OutputType output;
+    int16x4_t res_16 = vqmovn_s32(input.reg[0]);
+    uint8x8_t res_8 = vqmovun_s16(vcombine_s16(res_16, res_16));
+    output.reg[0] = vget_lane_u32(vreinterpret_u32_u8(res_8), 0);
+    return output;
+  }
+};
 
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+                                 RegBufferInt32<8>> {
+  typedef RegBufferInt32<8> InputType;
+  typedef RegBufferUint8<8> OutputType;
+
+  typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    int16x8_t res_16 =
+        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+    output.reg[0] = vqmovun_s16(res_16);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+                                 RegBufferInt32<16>> {
+  typedef RegBufferInt32<16> InputType;
+  typedef RegBufferUint8<16> OutputType;
+
+  typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    int16x8_t res_16_0 =
+        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+    int16x8_t res_16_1 =
+        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
+    output.reg[0] = vqmovun_s16(res_16_0);
+    output.reg[1] = vqmovun_s16(res_16_1);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+                                 RegBufferInt32<32>> {
+  typedef RegBufferInt32<32> InputType;
+  typedef RegBufferUint8<32> OutputType;
+
+  typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    int16x8_t res_16[4];
     for (int i = 0; i < 4; i++) {
-      output.data.val[i] =
-          impl_int32x4.Eval(input.data.val[i], row + 4 * i, col);
+      res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
+                               vqmovn_s32(input.reg[2 * i + 1]));
+    }
+    for (int i = 0; i < 4; i++) {
+      output.reg[i] = vqmovun_s16(res_16[i]);
     }
     return output;
   }
-
-  ImplInt32x4 impl_int32x4;
 };
 
-// Implementation of OutputStageQuantizeDownInt32ToUint8Scale for
-// NEONFragmentInt32x4x1
-template <>
-struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8Scale,
-                           NEONFragmentInt32x4x1> {
-  typedef NEONFragmentInt32x4x1 InputType;
-  typedef NEONFragmentInt32x4x1 OutputType;
-  typedef OutputStageQuantizeDownInt32ToUint8Scale OutputStage;
-
-  OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
-
-  OutputType Eval(InputType input, int, int) const {
-    const std::int32_t result_shift = output_stage.result_shift;
-    const std::int32_t result_mult_int = output_stage.result_mult_int;
-    const std::int32_t result_offset = output_stage.result_offset;
-    const std::int32_t preshift_offset =
-        (result_shift < 1) ? 0 : (1 << (result_shift - 1));
-    const int32x4_t a = vaddq_s32(input, vdupq_n_s32(result_offset));
-    const int32x4_t b =
-        vmlaq_n_s32(vdupq_n_s32(preshift_offset), a, result_mult_int);
-    return vshlq_s32(b, vdupq_n_s32(-result_shift));
-  }
-
-  const OutputStage& output_stage;
-};
-
-// Implementation of OutputStageQuantizeDownInt32ToUint8ScalePC for
-// NEONFragmentInt32x4x1
-template <>
-struct OutputStageEvalImpl<
-    OutputStageQuantizeDownInt32ToUint8ScalePC<VectorShape::Col>,
-    NEONFragmentInt32x4x1> {
-  typedef NEONFragmentInt32x4x1 InputType;
-  typedef NEONFragmentInt32x4x1 OutputType;
-  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<VectorShape::Col>
-      OutputStage;
-
-  OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
-
-  OutputType Eval(InputType input, int row, int col) const {
-    const std::int32_t result_shift = output_stage.result_shift;
-    const std::int32_t preshift_offset =
-        (result_shift < 1) ? 0 : (1 << (result_shift - 1));
-    const int32x4_t result_mult_int =
-        vld1q_s32(output_stage.result_mult_int.data(row));
-    const int32x4_t result_offset =
-        vld1q_s32(output_stage.result_offset.data(row));
-    const int32x4_t a = vaddq_s32(input, result_offset);
-    const int32x4_t b =
-        vmlaq_s32(vdupq_n_s32(preshift_offset), a, result_mult_int);
-    return vshlq_s32(b, vdupq_n_s32(-result_shift));
-  }
-
-  const OutputStage& output_stage;
-};
-
-// Implementation of OutputStageQuantizeDownInt32ToUint8ScalePC for
-// NEONFragmentInt32x4x1
-template <>
-struct OutputStageEvalImpl<
-    OutputStageQuantizeDownInt32ToUint8ScalePC<VectorShape::Row>,
-    NEONFragmentInt32x4x1> {
-  typedef NEONFragmentInt32x4x1 InputType;
-  typedef NEONFragmentInt32x4x1 OutputType;
-  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<VectorShape::Row>
-      OutputStage;
-
-  OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
-
-  OutputType Eval(InputType input, int row, int col) const {
-    const std::int32_t result_shift = output_stage.result_shift;
-    const std::int32_t preshift_offset =
-        (result_shift < 1) ? 0 : (1 << (result_shift - 1));
-    const int32x4_t result_mult_int =
-        vld1q_s32(output_stage.result_mult_int.data(col));
-    const int32x4_t result_offset =
-        vld1q_s32(output_stage.result_offset.data(row));
-    const int32x4_t a = vaddq_s32(input, result_offset);
-    const int32x4_t b =
-        vmlaq_s32(vdupq_n_s32(preshift_offset), a, result_mult_int);
-    return vshlq_s32(b, vdupq_n_s32(-result_shift));
-  }
-
-  const OutputStage& output_stage;
-};
-
-// Implementation of OutputStageSaturatingCastToUint8 for NEONFragmentInt32x4x1
-template <>
-struct OutputStageEvalImpl<OutputStageSaturatingCastToUint8,
-                           NEONFragmentInt32x4x1> {
-  typedef NEONFragmentInt32x4x1 InputType;
-  typedef NEONFragmentUint8x4x1 OutputType;
-  typedef OutputStageSaturatingCastToUint8 OutputStage;
-
-  OutputStageEvalImpl(const OutputStage&) {}
-
-  OutputType Eval(InputType input, int, int) const {
-    int16x8_t q16 = vcombine_s16(vqmovn_s32(input), vdup_n_s16(0));
-    return vqmovun_s16(q16);
-  }
-};
-
-// In the case of OutputStageSaturatingCastToUint8, the handling of
-// NEONFragmentInt32x16x1 data can be made much more efficient by handling
-// it all at once, instead of as 4 separate int32x4 values as in the above
-// generic partial specialization. This also avoids the poor (50%) register
-// utilization of FragmentUint8x4x1: by handling 16 scalar values at once,
-// we are able to fill a uint8x16_t.
-template <>
-struct OutputStageEvalImpl<OutputStageSaturatingCastToUint8,
-                           NEONFragmentInt32x16x1> {
-  typedef NEONFragmentInt32x16x1 InputType;
-  typedef NEONFragmentUint8x16x1 OutputType;
-  typedef OutputStageSaturatingCastToUint8 OutputStage;
-
-  OutputStageEvalImpl(const OutputStage&) {}
-
-  OutputType Eval(InputType input, int, int) const {
-    int16x8_t q16[2];
-    for (int i = 0; i < 2; i++) {
-      q16[i] = vcombine_s16(vqmovn_s32(input.data.val[2 * i]),
-                            vqmovn_s32(input.data.val[2 * i + 1]));
-    }
-    return vcombine_u8(vqmovun_s16(q16[0]), vqmovun_s16(q16[1]));
-  }
-};
-
-// Implementation of OutputStageBiasAddition for NEONFragmentInt32x4x1
-template <typename VectorType>
-struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>,
-                           NEONFragmentInt32x4x1> {
-  typedef NEONFragmentInt32x4x1 InputType;
-  typedef NEONFragmentInt32x4x1 OutputType;
-  typedef OutputStageBiasAddition<VectorType> OutputStage;
-
-  OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
-
-  OutputType Eval(InputType input, int row, int col) const {
-    int32x4_t bias;
-    if (VectorType::kShape == VectorShape::Row) {
-      bias = vdupq_n_s32(output_stage.bias_vector(col));
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
+  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
+                  int col) {
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
     } else {
-      bias = vld1q_s32(output_stage.bias_vector.data(row));
+      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
+      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
+      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
+      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
+      *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
+      *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
+      *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
+      *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
     }
-    return vaddq_s32(input, bias);
   }
-
-  const OutputStage& output_stage;
 };
 
-// Implementation of OutputStageClamp for NEONFragmentInt32x4x1
-template <>
-struct OutputStageEvalImpl<OutputStageClamp, NEONFragmentInt32x4x1> {
-  typedef NEONFragmentInt32x4x1 InputType;
-  typedef NEONFragmentInt32x4x1 OutputType;
-  typedef OutputStageClamp OutputStage;
+inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
+  const int32x4x2_t t0 = vtrnq_s32(src.buf.reg[0], src.buf.reg[1]);
+  const int32x4x2_t t1 = vtrnq_s32(src.buf.reg[2], src.buf.reg[3]);
+  RegBlockInt32<4, 4> result;
+  result.buf.reg[0] =
+      vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0]));
+  result.buf.reg[1] =
+      vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1]));
+  result.buf.reg[2] =
+      vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0]));
+  result.buf.reg[3] =
+      vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1]));
+  return result;
+}
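For readers less used to vtrnq_s32/vcombine_s32, the routine above is simply a 4x4 int32 transpose in which each register holds one column of the block; a scalar reference of the same operation (a sketch over a toy type, not the NEON path):

// Scalar reference of a 4x4 transpose where reg[c] holds column c.
#include <cstdint>
#include <cstdio>

struct ToyBlock4x4 {
  std::int32_t reg[4][4];  // reg[c][lane] ~ column c of the block
};

ToyBlock4x4 TransposeSketch(const ToyBlock4x4& src) {
  ToyBlock4x4 result;
  for (int c = 0; c < 4; ++c) {
    for (int lane = 0; lane < 4; ++lane) {
      result.reg[c][lane] = src.reg[lane][c];
    }
  }
  return result;
}

int main() {
  ToyBlock4x4 m = {};
  for (int c = 0; c < 4; ++c)
    for (int lane = 0; lane < 4; ++lane) m.reg[c][lane] = 10 * c + lane;
  const ToyBlock4x4 t = TransposeSketch(m);
  std::printf("t.reg[0] = {%d, %d, %d, %d}\n", t.reg[0][0], t.reg[0][1],
              t.reg[0][2], t.reg[0][3]);  // {0, 10, 20, 30}
  return 0;
}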
 
-  OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
-
-  OutputType Eval(InputType input, int, int) const {
-    const int32x4_t min = vdupq_n_s32(output_stage.min);
-    const int32x4_t max = vdupq_n_s32(output_stage.max);
-    return vminq_s32(vmaxq_s32(input, min), max);
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
+  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
+                  int col) {
+    const auto& block =
+        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
+    std::int32_t* dst_ptr = dst->data(row, col);
+    int stride = dst->stride();
+    for (int i = 0; i < 4; i++) {
+      vst1q_s32(dst_ptr + i * stride, block.buf.reg[i]);
+    }
   }
-
-  const OutputStage& output_stage;
 };
 
-// Implementation of OutputStageTanh for NEONFragmentInt32x4x1
-template <>
-struct OutputStageEvalImpl<OutputStageTanh, NEONFragmentInt32x4x1>
-    : OutputStageTanhEvalImpl<NEONFragmentInt32x4x1> {
-  OutputStageEvalImpl(const OutputStageTanh& output_stage)
-      : OutputStageTanhEvalImpl(output_stage) {}
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
+  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
+                  int col) {
+    std::int32_t* dst_ptr = dst->data(row, col);
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      int col_stride = dst->cols_stride();
+      for (int i = 0; i < 4; i++) {
+        vst1q_s32(dst_ptr + i * col_stride + 0, src.buf.reg[2 * i + 0]);
+        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
+      }
+    } else {
+      int row_stride = dst->rows_stride();
+      RegBlockInt32<4, 4> top;
+      top.buf.reg[0] = src.buf.reg[0];
+      top.buf.reg[1] = src.buf.reg[2];
+      top.buf.reg[2] = src.buf.reg[4];
+      top.buf.reg[3] = src.buf.reg[6];
+      const auto transpose_top = Transpose(top);
+      for (int i = 0; i < 4; i++) {
+        vst1q_s32(dst_ptr + i * row_stride, transpose_top.buf.reg[i]);
+      }
+      RegBlockInt32<4, 4> bottom;
+      bottom.buf.reg[0] = src.buf.reg[1];
+      bottom.buf.reg[1] = src.buf.reg[3];
+      bottom.buf.reg[2] = src.buf.reg[5];
+      bottom.buf.reg[3] = src.buf.reg[7];
+      const auto transpose_bottom = Transpose(bottom);
+      for (int i = 0; i < 4; i++) {
+        vst1q_s32(dst_ptr + (i + 4) * row_stride, transpose_bottom.buf.reg[i]);
+      }
+    }
+  }
 };
 
-// Specialization of StoreFinalOutput for NEONFragmentUint8x4x1.
-// This is quite inefficient, but we have no choice: instructions storing 32bit
-// at once also assume 32bit alignment. In practice, this slowness is not a
-// problem because we use the x16 path for most values.
 template <typename DstType>
-inline void StoreFinalOutput(NEONFragmentUint8x4x1 value, DstType* dst, int row,
-                             int col) {
-  vst1_lane_u8(dst->data(row + 0, col), value, 0);
-  vst1_lane_u8(dst->data(row + 1, col), value, 1);
-  vst1_lane_u8(dst->data(row + 2, col), value, 2);
-  vst1_lane_u8(dst->data(row + 3, col), value, 3);
-}
-
-// Specialization of StoreFinalOutput for NEONFragmentUint8x16x1.
-template <typename DstType>
-inline void StoreFinalOutput(NEONFragmentUint8x16x1 value, DstType* dst,
-                             int row, int col) {
-  vst1q_u8(dst->data(row, col), value);
-}
-
-// Specialization of StoreFinalOutput for NEONFragmentInt32x4x1, storing into a
-// int32 destination.
-template <typename DstType>
-inline void StoreFinalOutput(NEONFragmentInt32x4x1 value, DstType* dst, int row,
-                             int col) {
-  vst1q_s32(dst->data(row, col), value);
-}
-
-// Specialization of StoreFinalOutput for NEONFragmentInt32x16x1, storing into
-// a int32 destination.
-template <typename DstType>
-inline void StoreFinalOutput(NEONFragmentInt32x16x1 value, DstType* dst,
-                             int row, int col) {
-  for (int i = 0; i < 4; i++) {
-    vst1q_s32(dst->data(row + 4 * i, col), value.data.val[i]);
+struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
+  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
+                  int col) {
+    std::int32_t* dst_ptr = dst->data(row, col);
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      int col_stride = dst->cols_stride();
+      for (int i = 0; i < 8; i++) {
+        vst1q_s32(dst_ptr + i * col_stride, src.buf.reg[2 * i]);
+        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
+      }
+    } else {
+      int row_stride = dst->rows_stride();
+      RegBlockInt32<4, 4> top_left;
+      top_left.buf.reg[0] = src.buf.reg[0];
+      top_left.buf.reg[1] = src.buf.reg[2];
+      top_left.buf.reg[2] = src.buf.reg[4];
+      top_left.buf.reg[3] = src.buf.reg[6];
+      const auto transpose_top_left = Transpose(top_left);
+      for (int i = 0; i < 4; i++) {
+        vst1q_s32(dst_ptr + i * row_stride, transpose_top_left.buf.reg[i]);
+      }
+      RegBlockInt32<4, 4> bottom_left;
+      bottom_left.buf.reg[0] = src.buf.reg[1];
+      bottom_left.buf.reg[1] = src.buf.reg[3];
+      bottom_left.buf.reg[2] = src.buf.reg[5];
+      bottom_left.buf.reg[3] = src.buf.reg[7];
+      const auto transpose_bottom_left = Transpose(bottom_left);
+      for (int i = 0; i < 4; i++) {
+        vst1q_s32(dst_ptr + (i + 4) * row_stride,
+                  transpose_bottom_left.buf.reg[i]);
+      }
+      RegBlockInt32<4, 4> top_right;
+      top_right.buf.reg[0] = src.buf.reg[8];
+      top_right.buf.reg[1] = src.buf.reg[10];
+      top_right.buf.reg[2] = src.buf.reg[12];
+      top_right.buf.reg[3] = src.buf.reg[14];
+      const auto transpose_top_right = Transpose(top_right);
+      for (int i = 0; i < 4; i++) {
+        vst1q_s32(dst_ptr + i * row_stride + 4, transpose_top_right.buf.reg[i]);
+      }
+      RegBlockInt32<4, 4> bottom_right;
+      bottom_right.buf.reg[0] = src.buf.reg[9];
+      bottom_right.buf.reg[1] = src.buf.reg[11];
+      bottom_right.buf.reg[2] = src.buf.reg[13];
+      bottom_right.buf.reg[3] = src.buf.reg[15];
+      const auto transpose_bottom_right = Transpose(bottom_right);
+      for (int i = 0; i < 4; i++) {
+        vst1q_s32(dst_ptr + (i + 4) * row_stride + 4,
+                  transpose_bottom_right.buf.reg[i]);
+      }
+    }
   }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
+  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
+                  int col) {
+    std::int32_t* dst_ptr = dst->data(row, col);
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      vst1q_s32(dst_ptr, src.buf.reg[0]);
+    } else {
+      int row_stride = dst->rows_stride();
+      vst1q_lane_s32(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
+      vst1q_lane_s32(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
+      vst1q_lane_s32(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
+      vst1q_lane_s32(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
+  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
+                  int col) {
+    std::int32_t* dst_ptr = dst->data(row, col);
+    if (DstType::kOrder == MapOrder::RowMajor) {
+      vst1q_s32(dst_ptr, src.buf.reg[0]);
+    } else {
+      int col_stride = dst->cols_stride();
+      vst1q_lane_s32(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
+      vst1q_lane_s32(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
+      vst1q_lane_s32(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
+      vst1q_lane_s32(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
+  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
+                  int col) {
+    const std::uint32_t src_reg = src.buf.reg[0];
+    for (int i = 0; i < 4; i++) {
+      *dst->data(row + i, col) = (src_reg >> (8 * i));
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
+  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
+                  int col) {
+    for (int i = 0; i < 4; i++) {
+      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
+  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
+                  int col) {
+    std::uint8_t* dst_ptr = dst->data(row, col);
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      vst1_u8(dst_ptr, src.buf.reg[0]);
+    } else {
+      const int row_stride = dst->rows_stride();
+      vst1_lane_u8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
+      vst1_lane_u8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
+      vst1_lane_u8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
+      vst1_lane_u8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
+      vst1_lane_u8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
+      vst1_lane_u8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
+      vst1_lane_u8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
+      vst1_lane_u8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
+  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
+                  int col) {
+    std::uint8_t* dst_ptr = dst->data(row, col);
+    const int row_stride = dst->rows_stride();
+    const int col_stride = dst->cols_stride();
+    for (int i = 0; i < 2; i++) {
+      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
+                   src.buf.reg[i], 0);
+      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
+                   src.buf.reg[i], 1);
+      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
+                   src.buf.reg[i], 2);
+      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
+                   src.buf.reg[i], 3);
+      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
+                   src.buf.reg[i], 4);
+      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
+                   src.buf.reg[i], 5);
+      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
+                   src.buf.reg[i], 6);
+      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
+                   src.buf.reg[i], 7);
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
+  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
+                  int col) {
+    std::uint8_t* dst_ptr = dst->data(row, col);
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      int col_stride = dst->cols_stride();
+      for (int i = 0; i < 4; i++) {
+        vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
+      }
+    } else {
+      for (int i = 0; i < 4; i++) {
+        int row_stride = dst->rows_stride();
+        std::uint8_t* col_ptr = dst_ptr + i;
+        vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
+        vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
+        vst1_lane_u8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
+        vst1_lane_u8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
+        vst1_lane_u8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
+        vst1_lane_u8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
+        vst1_lane_u8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
+        vst1_lane_u8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
+      }
+    }
+  }
+};
+
+inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
+  uint8x8x2_t a[4];
+  a[0] = vtrn_u8(src.buf.reg[0], src.buf.reg[1]);
+  a[1] = vtrn_u8(src.buf.reg[2], src.buf.reg[3]);
+  a[2] = vtrn_u8(src.buf.reg[4], src.buf.reg[5]);
+  a[3] = vtrn_u8(src.buf.reg[6], src.buf.reg[7]);
+  uint16x4x2_t b[4];
+  b[0] = vtrn_u16(vreinterpret_u16_u8(a[0].val[0]),
+                  vreinterpret_u16_u8(a[1].val[0]));
+  b[1] = vtrn_u16(vreinterpret_u16_u8(a[0].val[1]),
+                  vreinterpret_u16_u8(a[1].val[1]));
+  b[2] = vtrn_u16(vreinterpret_u16_u8(a[2].val[0]),
+                  vreinterpret_u16_u8(a[3].val[0]));
+  b[3] = vtrn_u16(vreinterpret_u16_u8(a[2].val[1]),
+                  vreinterpret_u16_u8(a[3].val[1]));
+  uint32x2x2_t c[4];
+  c[0] = vtrn_u32(vreinterpret_u32_u16(b[0].val[0]),
+                  vreinterpret_u32_u16(b[2].val[0]));
+  c[1] = vtrn_u32(vreinterpret_u32_u16(b[1].val[0]),
+                  vreinterpret_u32_u16(b[3].val[0]));
+  c[2] = vtrn_u32(vreinterpret_u32_u16(b[0].val[1]),
+                  vreinterpret_u32_u16(b[2].val[1]));
+  c[3] = vtrn_u32(vreinterpret_u32_u16(b[1].val[1]),
+                  vreinterpret_u32_u16(b[3].val[1]));
+  RegBlockUint8<8, 8> result;
+  result.buf.reg[0] = vreinterpret_u8_u32(c[0].val[0]);
+  result.buf.reg[1] = vreinterpret_u8_u32(c[1].val[0]);
+  result.buf.reg[2] = vreinterpret_u8_u32(c[2].val[0]);
+  result.buf.reg[3] = vreinterpret_u8_u32(c[3].val[0]);
+  result.buf.reg[4] = vreinterpret_u8_u32(c[0].val[1]);
+  result.buf.reg[5] = vreinterpret_u8_u32(c[1].val[1]);
+  result.buf.reg[6] = vreinterpret_u8_u32(c[2].val[1]);
+  result.buf.reg[7] = vreinterpret_u8_u32(c[3].val[1]);
+  return result;
 }
 
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
+  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
+                  int col) {
+    const auto& block =
+        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
+    std::uint8_t* dst_ptr = dst->data(row, col);
+    int stride = dst->stride();
+    for (int i = 0; i < 8; i++) {
+      vst1_u8(dst_ptr + i * stride, block.buf.reg[i]);
+    }
+  }
+};
+
 }  // namespace gemmlowp
 
 #endif  // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
diff --git a/internal/output_sse.h b/internal/output_sse.h
new file mode 100644
index 0000000..5c06253
--- /dev/null
+++ b/internal/output_sse.h
@@ -0,0 +1,354 @@
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// output_sse.h: optimized SSE4.2 specializations of the templates in output.h.
+
+#ifndef GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
+#define GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
+
+#include "output.h"
+
+#include <smmintrin.h>
+
+namespace gemmlowp {
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+                                 RegBufferInt32<4>> {
+  typedef RegBufferInt32<4> InputType;
+  typedef RegBufferUint8<4> OutputType;
+
+  typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]);
+    __m128i res_8 = _mm_packus_epi16(res_16, res_16);
+    output.reg[0] = _mm_cvtsi128_si32(res_8);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+                                 RegBufferInt32<8>> {
+  typedef RegBufferInt32<8> InputType;
+  typedef RegBufferUint8<8> OutputType;
+
+  typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[1]);
+    __m128i res_8 = _mm_packus_epi16(res_16, res_16);
+    output.reg[0] = _mm_extract_epi32(res_8, 0);
+    output.reg[1] = _mm_extract_epi32(res_8, 1);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+                                 RegBufferInt32<16>> {
+  typedef RegBufferInt32<16> InputType;
+  typedef RegBufferUint8<16> OutputType;
+
+  typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
+    __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
+    output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
+    return output;
+  }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+                                 RegBufferInt32<32>> {
+  typedef RegBufferInt32<32> InputType;
+  typedef RegBufferUint8<32> OutputType;
+
+  typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+  OutputStageEvalBufferImpl(const OutputStage&) {}
+
+  OutputType Eval(InputType input) const {
+    OutputType output;
+    __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
+    __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
+    output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
+    __m128i res_16_2 = _mm_packs_epi32(input.reg[4], input.reg[5]);
+    __m128i res_16_3 = _mm_packs_epi32(input.reg[6], input.reg[7]);
+    output.reg[1] = _mm_packus_epi16(res_16_2, res_16_3);
+    return output;
+  }
+};
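+
+Each of the _mm_packs_epi32 / _mm_packus_epi16 pairs above is a signed-then-unsigned saturating narrowing; per element it is equivalent to this scalar sketch (helper name is illustrative):
+
+#include <algorithm>
+#include <cstdint>
+
+// Clamp an int32 accumulator into the [0, 255] range of a uint8 output.
+inline std::uint8_t SaturatingCastToUint8Reference(std::int32_t x) {
+  return static_cast<std::uint8_t>(std::min(255, std::max(0, x)));
+}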
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
+  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
+                  int col) {
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+    } else {
+      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
+      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
+      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
+      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
+  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
+                  int col) {
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
+    } else {
+      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
+      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
+      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
+      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
+      *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
+      *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
+      *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
+      *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
+    }
+  }
+};
+
+inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
+  __m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]);
+  __m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]);
+  __m128i t2 = _mm_unpackhi_epi32(src.buf.reg[0], src.buf.reg[1]);
+  __m128i t3 = _mm_unpackhi_epi32(src.buf.reg[2], src.buf.reg[3]);
+
+  RegBlockInt32<4, 4> result;
+  result.buf.reg[0] = _mm_unpacklo_epi64(t0, t1);
+  result.buf.reg[1] = _mm_unpackhi_epi64(t0, t1);
+  result.buf.reg[2] = _mm_unpacklo_epi64(t2, t3);
+  result.buf.reg[3] = _mm_unpackhi_epi64(t2, t3);
+  return result;
+}
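+
+The unpacklo/unpackhi/unpack*_epi64 sequence above is a 4x4 in-register transpose; as a plain scalar reference (illustrative helper):
+
+#include <cstdint>
+
+// dst(r, c) = src(c, r)
+void Transpose4x4Reference(const std::int32_t src[4][4],
+                           std::int32_t dst[4][4]) {
+  for (int r = 0; r < 4; r++) {
+    for (int c = 0; c < 4; c++) {
+      dst[r][c] = src[c][r];
+    }
+  }
+}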
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
+  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
+                  int col) {
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      for (int i = 0; i < 4; i++) {
+        StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
+      }
+    } else {
+      const auto transpose = Transpose(src);
+      for (int i = 0; i < 4; i++) {
+        StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
+      }
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
+  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
+                  int col) {
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      for (int i = 0; i < 4; i++) {
+        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
+        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
+      }
+    } else {
+      RegBlockInt32<4, 4> top;
+      top.buf.reg[0] = src.buf.reg[0];
+      top.buf.reg[1] = src.buf.reg[2];
+      top.buf.reg[2] = src.buf.reg[4];
+      top.buf.reg[3] = src.buf.reg[6];
+      const auto transpose_top = Transpose(top);
+      for (int i = 0; i < 4; i++) {
+        StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
+      }
+      RegBlockInt32<4, 4> bottom;
+      bottom.buf.reg[0] = src.buf.reg[1];
+      bottom.buf.reg[1] = src.buf.reg[3];
+      bottom.buf.reg[2] = src.buf.reg[5];
+      bottom.buf.reg[3] = src.buf.reg[7];
+      const auto transpose_bottom = Transpose(bottom);
+      for (int i = 0; i < 4; i++) {
+        StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
+      }
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
+  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
+                  int col) {
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      for (int i = 0; i < 8; i++) {
+        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
+        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
+      }
+    } else {
+      RegBlockInt32<4, 4> top_left;
+      top_left.buf.reg[0] = src.buf.reg[0];
+      top_left.buf.reg[1] = src.buf.reg[2];
+      top_left.buf.reg[2] = src.buf.reg[4];
+      top_left.buf.reg[3] = src.buf.reg[6];
+      const auto transpose_top_left = Transpose(top_left);
+      for (int i = 0; i < 4; i++) {
+        StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
+      }
+      RegBlockInt32<4, 4> bottom_left;
+      bottom_left.buf.reg[0] = src.buf.reg[1];
+      bottom_left.buf.reg[1] = src.buf.reg[3];
+      bottom_left.buf.reg[2] = src.buf.reg[5];
+      bottom_left.buf.reg[3] = src.buf.reg[7];
+      const auto transpose_bottom_left = Transpose(bottom_left);
+      for (int i = 0; i < 4; i++) {
+        StoreInt32x4(dst->data(row + 4 + i, col),
+                     transpose_bottom_left.buf.reg[i]);
+      }
+      RegBlockInt32<4, 4> top_right;
+      top_right.buf.reg[0] = src.buf.reg[8];
+      top_right.buf.reg[1] = src.buf.reg[10];
+      top_right.buf.reg[2] = src.buf.reg[12];
+      top_right.buf.reg[3] = src.buf.reg[14];
+      const auto transpose_top_right = Transpose(top_right);
+      for (int i = 0; i < 4; i++) {
+        StoreInt32x4(dst->data(row + i, col + 4),
+                     transpose_top_right.buf.reg[i]);
+      }
+      RegBlockInt32<4, 4> bottom_right;
+      bottom_right.buf.reg[0] = src.buf.reg[9];
+      bottom_right.buf.reg[1] = src.buf.reg[11];
+      bottom_right.buf.reg[2] = src.buf.reg[13];
+      bottom_right.buf.reg[3] = src.buf.reg[15];
+      const auto transpose_bottom_right = Transpose(bottom_right);
+      for (int i = 0; i < 4; i++) {
+        StoreInt32x4(dst->data(row + 4 + i, col + 4),
+                     transpose_bottom_right.buf.reg[i]);
+      }
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
+  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
+                  int col) {
+    if (DstType::kOrder == MapOrder::ColMajor) {
+      *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
+      *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
+      *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
+      *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
+    } else {
+      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
+  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
+                  int col) {
+    const std::uint32_t src_reg = src.buf.reg[0];
+    for (int i = 0; i < 4; i++) {
+      *dst->data(row + i, col) = (src_reg >> (8 * i));
+    }
+  }
+};
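+
+In the RegBlockUint8<4, 1> store above, four uint8 results are packed into one 32-bit register; shifting right by 8*i and letting the narrowing store keep the low byte extracts lane i. A sketch of that extraction, assuming lane 0 sits in the least significant byte:
+
+#include <cstdint>
+
+inline std::uint8_t ExtractByteLane(std::uint32_t packed, int lane) {
+  // Lane i occupies bits [8*i, 8*i + 7].
+  return static_cast<std::uint8_t>(packed >> (8 * lane));
+}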
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
+  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
+                  int col) {
+    for (int i = 0; i < 4; i++) {
+      *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
+    }
+    for (int i = 0; i < 4; i++) {
+      *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
+  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
+                  int col) {
+    for (int i = 0; i < 4; i++) {
+      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
+  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
+                  int col) {
+    std::uint8_t buf[16];
+    StoreUint8x16(buf, src.buf.reg[0]);
+    for (int c = 0; c < 4; c++) {
+      for (int r = 0; r < 4; r++) {
+        *dst->data(row + r, col + c) = buf[r + 4 * c];
+      }
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
+  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
+                  int col) {
+    std::uint8_t buf[32];
+    StoreUint8x16(buf, src.buf.reg[0]);
+    StoreUint8x16(buf + 16, src.buf.reg[1]);
+    for (int c = 0; c < 4; c++) {
+      for (int r = 0; r < 8; r++) {
+        *dst->data(row + r, col + c) = buf[r + 8 * c];
+      }
+    }
+  }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
+  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
+                  int col) {
+    std::uint8_t buf[64];
+    StoreUint8x16(buf, src.buf.reg[0]);
+    StoreUint8x16(buf + 16, src.buf.reg[1]);
+    StoreUint8x16(buf + 32, src.buf.reg[2]);
+    StoreUint8x16(buf + 48, src.buf.reg[3]);
+    for (int c = 0; c < 8; c++) {
+      for (int r = 0; r < 8; r++) {
+        *dst->data(row + r, col + c) = buf[r + 8 * c];
+      }
+    }
+  }
+};
+
+}  // namespace gemmlowp
+
+#endif  // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
diff --git a/internal/pack.h b/internal/pack.h
index 4531f79..3395396 100644
--- a/internal/pack.h
+++ b/internal/pack.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -29,7 +29,6 @@
 
 #include <cstring>
 
-#include "../public/bit_depth.h"
 #include "allocator.h"
 #include "block_params.h"
 #include "common.h"
@@ -51,8 +50,7 @@
 
   PackedSideBlock(Side side, Allocator* allocator,
                   const BlockParams& block_params)
-      : allocator_(allocator),
-        pos_(0) {
+      : allocator_(allocator), pos_(0) {
     GetSideBlockParams(side, &params_, block_params);
     data_handle_ =
         allocator_->Reserve<std::uint8_t>(params_.l2_width * params_.l2_depth);
@@ -189,94 +187,6 @@
   int width_, depth_, stride_;
 };
 
-template <RoundingMode tRoundingMode>
-class ScalarRoundingOffsetGenerator {
- public:
-  std::uint8_t get() {
-    assert(false);  // This generic path should never be called.
-    return 0;
-  }
-};
-
-// A RoundingOffsetGenerator for rounding-to-nearest, always returning
-// the midpoint value 127.
-template <>
-class ScalarRoundingOffsetGenerator<RoundingMode::Nearest> {
- public:
-  std::uint8_t get() { return 127; }
-};
-
-// A RoundingOffsetGenerator based on a 8-bit Xorshift.
-// This gives good results as Xorshift naturally generates
-// uniform random *nonzero* bytes i.e. 255 different values,
-// so it only remains for us to subtract one.
-template <>
-class ScalarRoundingOffsetGenerator<RoundingMode::ProbabilisticXorshift> {
- public:
-  ScalarRoundingOffsetGenerator() { x_ = 128; }
-
-  std::uint8_t get() {
-    std::uint8_t result = x_ - 1;
-    // Xorshift8(7,5,3)
-    x_ ^= x_ << 7;
-    x_ ^= x_ >> 5;
-    x_ ^= x_ << 3;
-    return result;
-  }
-
- private:
-  // State
-  std::uint8_t x_;
-};
-
-// A RoundingOffsetGenerator based on an 8-bit add/mod
-// low-discrepancy sequence.  See less-than-8-bit.txt for
-// an explanation (the constant 97 is important - it must
-// be both relatively prime to 255, in order for the sequence
-// to be full-period, and c/255 should be close to 0.38 to
-// obtain low discrepancy).  Uses a small bit hack to avoid
-// expensive % operations.
-template <>
-class ScalarRoundingOffsetGenerator<RoundingMode::ProbabilisticAddmod> {
-  static const uint8_t AddConst = 97;
-
- public:
-  ScalarRoundingOffsetGenerator() { x_ = 1; }  // Start must be non-zero
-
-  std::uint8_t get() {
-    // The +'d boolean term causes the increment to skip over 255,
-    // (recalling that 255+1 = 256 = 0 for an 8 bit uint),
-    // thus implementing %255
-    x_ += (AddConst + (x_ >= (255 - AddConst)));
-    return x_;
-  }
-
- private:
-  // State
-  std::uint8_t x_;
-};
-
-// Requantizes a source uint8 value in [0..255] range
-// to the range specified by BitDepth, [0..((2^bits)-1)].
-// Bias must be avoided. Currently this is achieved
-// by probabilistic rounding.
-template <typename QuantizationParams>
-std::uint8_t Requantize(
-    std::uint8_t raw_src_val,
-    ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>*
-        rounding_offset_generator) {
-  static const int kBits = QuantizationParams::BitDepth::kBits;
-  static const std::uint8_t kMaxVal = (1 << kBits) - 1;
-
-  if (kBits == 8) {
-    return raw_src_val;
-  }
-
-  std::uint16_t scaled = static_cast<std::uint16_t>(raw_src_val) * kMaxVal;
-  std::uint8_t rounding_offset = rounding_offset_generator->get();
-  return (scaled + rounding_offset) / 255;
-}
-
 // A PackingRegisterBlock is a small fixed-size block of a matrix being
 // packed. This class is the generic non-optimized implementation,
 // it is inherited by the generic implementation of PackingRegisterBlock,
@@ -293,21 +203,20 @@
 //   2. Packing a complete block into the destination, see Pack. This is the
 //      most critical part, so it's convenient that unaligned boundaries have
 //      already been handled in step 1.
-template <typename QuantizationParams, typename SrcMapType,
-          typename PackedSideBlock>
+template <typename SrcMapType, typename PackedSideBlock>
 class PackingRegisterBlockBase {
  public:
   typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat;
   typedef typename KernelSideFormat::Cell CellFormat;
+  typedef typename KernelSideFormat::Scalar KernelScalar;
   static const int kCells = KernelSideFormat::kCells;
   static const int kCellWidth = CellFormat::kWidth;
   static const int kKernelWidth = CellFormat::kWidth * kCells;
   static const int kCellDepth = CellFormat::kDepth;
   static const int kCellSize = CellFormat::kSize;
   static const SideMapOrder kSrcOrder = SrcMapType::kOrder;
-
-  typedef ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>
-      RoundingOffsetGenerator;
+  static const int kZeroPointInputValue =
+      ZeroPointInputValue<KernelScalar>::kValue;
 
   PackingRegisterBlockBase() : complete_src_(nullptr, 0, 0, 0) {}
 
@@ -329,7 +238,7 @@
   // Copies an incomplete block of source data into a local temporary
   // complete block by zero-extending it.
   void MakeCompleteSrc(const SrcMapType& src) {
-    memset(buf_, 0, kKernelWidth * kRegisterSize);
+    memset(buf_, kZeroPointInputValue, kKernelWidth * kRegisterSize);
     if (kSrcOrder == SideMapOrder::WidthMajor) {
       for (int w = 0; w < src.width(); w++) {
         memcpy(buf_ + w * kRegisterSize, src.data(w, 0), src.depth());
@@ -345,8 +254,7 @@
   // Packs a complete block into the destination. This is the most
   // critical part and the part that we most typically want to
   // override in architecture-specific optimized specializations.
-  void Pack(PackedSideBlock* dst, int start_width,
-            RoundingOffsetGenerator* rounding_offset_generator) {
+  void Pack(PackedSideBlock* dst, int start_width) {
     std::uint8_t* dst_ptr = dst->current_data();
     for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
          cell_start_depth += kCellDepth) {
@@ -360,11 +268,12 @@
         for (int w = 0; w < kCellWidth; w++) {
           std::int32_t sum = 0;
           for (int d = 0; d < kCellDepth; d++) {
-            const std::uint8_t raw_src_val = src_cell_map(w, d);
-            const std::uint8_t requantized = Requantize<QuantizationParams>(
-                raw_src_val, rounding_offset_generator);
-            dst_ptr[OffsetIntoCell<CellFormat>(w, d)] = requantized;
-            sum += requantized;
+            const std::uint8_t src_val = src_cell_map(w, d);
+            const std::int16_t kernel_val_unwrapped =
+                src_val - kZeroPointInputValue;
+            const std::uint8_t kernel_val_uint8 = kernel_val_unwrapped;
+            dst_ptr[OffsetIntoCell<CellFormat>(w, d)] = kernel_val_uint8;
+            sum += kernel_val_unwrapped;
           }
           cell_sums_of_each_slice_ptr[w] += sum;
         }
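
The rewritten Pack() step above drops the old probabilistic requantization in favor of a fixed per-value shift: each source byte is offset by kZeroPointInputValue, the shifted value is wrapped back into a uint8 for storage, and the unwrapped value feeds the per-slice sums. A minimal scalar sketch, under the assumption that kZeroPointInputValue is 0 for uint8 kernels and 128 for int8 kernels (as the 0x80 sign-bit flip in pack_neon.h below suggests):

#include <cstdint>

// Illustrative helper, not part of the patch.
std::uint8_t PackOneValue(std::uint8_t src_val, int zero_point_input_value,
                          std::int32_t* slice_sum) {
  // The sum uses the unwrapped (possibly negative) value; the stored byte
  // wraps modulo 256.
  const std::int16_t unwrapped = src_val - zero_point_input_value;
  *slice_sum += unwrapped;
  return static_cast<std::uint8_t>(unwrapped);
}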
@@ -375,15 +284,12 @@
   }
 };
 
-template <typename QuantizationParams, typename SrcMapType,
-          typename PackedSideBlock>
+template <typename SrcMapType, typename PackedSideBlock>
 class PackingRegisterBlock
-    : public PackingRegisterBlockBase<QuantizationParams, SrcMapType,
-                                      PackedSideBlock> {};
+    : public PackingRegisterBlockBase<SrcMapType, PackedSideBlock> {};
 
 // Large-scale implementation of packing.
-template <typename QuantizationParams, typename SrcMapType,
-          typename PackedSideBlock>
+template <typename SrcMapType, typename PackedSideBlock>
 class PackSideBlockImpl {
  public:
   typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat;
@@ -393,10 +299,8 @@
   static const int kKernelWidth = CellFormat::kWidth * kCells;
   static const int kCellDepth = CellFormat::kDepth;
 
-  typedef PackingRegisterBlock<QuantizationParams, SrcMapType, PackedSideBlock>
+  typedef PackingRegisterBlock<SrcMapType, PackedSideBlock>
       PackingRegisterBlockType;
-  typedef typename PackingRegisterBlockType::RoundingOffsetGenerator
-      RoundingOffsetGenerator;
 
   PackSideBlockImpl(PackedSideBlock* packed_side_block,
                     const SrcMapType& src_map)
@@ -462,14 +366,14 @@
         for (int d = 0; d < register_aligned_depth; d += kRegisterSize) {
           b.UseCompleteSrcInPlace(src_map_.block(start_width, start_depth + d,
                                                  width, kRegisterSize));
-          b.Pack(packed_side_block_, start_width, &rounding_offset_generator_);
+          b.Pack(packed_side_block_, start_width);
         }
       }
       if (register_aligned_depth < depth) {
         b.MakeCompleteSrc(
             src_map_.block(start_width, start_depth + register_aligned_depth,
                            width, depth - register_aligned_depth));
-        b.Pack(packed_side_block_, start_width, &rounding_offset_generator_);
+        b.Pack(packed_side_block_, start_width);
       }
     } else {
       assert(width < kKernelWidth);
@@ -477,7 +381,7 @@
         const int ds = std::min(+kRegisterSize, depth - d);
         b.MakeCompleteSrc(
             src_map_.block(start_width, start_depth + d, width, ds));
-        b.Pack(packed_side_block_, start_width, &rounding_offset_generator_);
+        b.Pack(packed_side_block_, start_width);
       }
     }
   }
@@ -488,24 +392,10 @@
   // A map on the block of the original matrix block being packed,
   // i.e. the 'source'.
   const SrcMapType& src_map_;
-
-  // Used for requantization in the less-than-8-bit case.
-  // Otherwise unused.
-  RoundingOffsetGenerator rounding_offset_generator_;
-};
-
-// Quantization parameters for the side (LHS or RHS) being packed,
-// with the rounding strategy having been already resolved to a specific
-// rounding mode.
-template <typename tBitDepth, RoundingMode tRoundingMode>
-struct QuantizationParams {
-  typedef tBitDepth BitDepth;
-  static const RoundingMode kRoundingMode = tRoundingMode;
 };
 
 // Packs a block of the input LHS matrix, into a PackedSideBlock
-template <typename BitDepthParams, typename PackedSideBlock,
-          typename MatrixMapType>
+template <typename PackedSideBlock, typename MatrixMapType>
 void PackLhs(PackedSideBlock* dst, const MatrixMapType& src) {
   ScopedProfilingLabel label("pack LHS");
   static const SideMapOrder kSideMapOrder =
@@ -514,29 +404,13 @@
   typedef typename MatrixMapType::Scalar Scalar;
   typedef SideMap<Scalar, kSideMapOrder> SideMapType;
   SideMapType src_side_map(src.data(), src.rows(), src.cols(), src.stride());
-  typedef typename BitDepthParams::LhsBitDepth BitDepth;
-  typedef typename BitDepthParams::RoundingStrategy RoundingStrategy;
-  const int accumulation_depth = src_side_map.depth();
-  if (accumulation_depth < RoundingStrategy::kRoundingModeSizeThreshold) {
-    typedef QuantizationParams<BitDepth,
-                               RoundingStrategy::kRoundingModeForSmallSizes>
-        QParams;
-    typedef PackSideBlockImpl<QParams, SideMapType, PackedSideBlock> ImplType;
-    ImplType impl(dst, src_side_map);
-    impl.PackL2();
-  } else {
-    typedef QuantizationParams<BitDepth,
-                               RoundingStrategy::kRoundingModeForLargeSizes>
-        QParams;
-    typedef PackSideBlockImpl<QParams, SideMapType, PackedSideBlock> ImplType;
-    ImplType impl(dst, src_side_map);
-    impl.PackL2();
-  }
+  typedef PackSideBlockImpl<SideMapType, PackedSideBlock> ImplType;
+  ImplType impl(dst, src_side_map);
+  impl.PackL2();
 }
 
 // Packs a block of the input RHS matrix, into a PackedSideBlock
-template <typename BitDepthParams, typename PackedSideBlock,
-          typename MatrixMapType>
+template <typename PackedSideBlock, typename MatrixMapType>
 void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) {
   ScopedProfilingLabel label("pack RHS");
   static const SideMapOrder kSideMapOrder =
@@ -545,24 +419,9 @@
   typedef typename MatrixMapType::Scalar Scalar;
   typedef SideMap<Scalar, kSideMapOrder> SideMapType;
   SideMapType src_side_map(src.data(), src.cols(), src.rows(), src.stride());
-  typedef typename BitDepthParams::RhsBitDepth BitDepth;
-  typedef typename BitDepthParams::RoundingStrategy RoundingStrategy;
-  const int accumulation_depth = src_side_map.depth();
-  if (accumulation_depth < RoundingStrategy::kRoundingModeSizeThreshold) {
-    typedef QuantizationParams<BitDepth,
-                               RoundingStrategy::kRoundingModeForSmallSizes>
-        QParams;
-    typedef PackSideBlockImpl<QParams, SideMapType, PackedSideBlock> ImplType;
-    ImplType impl(dst, src_side_map);
-    impl.PackL2();
-  } else {
-    typedef QuantizationParams<BitDepth,
-                               RoundingStrategy::kRoundingModeForLargeSizes>
-        QParams;
-    typedef PackSideBlockImpl<QParams, SideMapType, PackedSideBlock> ImplType;
-    ImplType impl(dst, src_side_map);
-    impl.PackL2();
-  }
+  typedef PackSideBlockImpl<SideMapType, PackedSideBlock> ImplType;
+  ImplType impl(dst, src_side_map);
+  impl.PackL2();
 }
 
 }  // namespace gemmlowp
@@ -570,7 +429,7 @@
 #ifdef GEMMLOWP_NEON
 #include "pack_neon.h"
 #elif defined(GEMMLOWP_SSE4)
-#include "pack_SSE.h"
+#include "pack_sse.h"
 #endif
 
 #endif  // GEMMLOWP_INTERNAL_PACK_H_
diff --git a/internal/pack_neon.h b/internal/pack_neon.h
index 4936b49..e212d07 100644
--- a/internal/pack_neon.h
+++ b/internal/pack_neon.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -23,151 +23,19 @@
 
 namespace gemmlowp {
 
-template <RoundingMode tRoundingMode>
-class NEONRoundingOffsetGenerator {
- public:
-  uint8x16_t get() {
-    assert(false);  // This generic path should never be called.
-    return vdupq_n_u8(0);
-  }
-};
-
-// A RoundingOffsetGenerator for rounding-to-nearest, always returning
-// the midpoint value 127.
-template <>
-class NEONRoundingOffsetGenerator<RoundingMode::Nearest> {
- public:
-  uint8x16_t get() { return vdupq_n_u8(127); }
-};
-
-// Variant of NEONRoundingOffsetGenerator that produces
-// random NEON 128-bit vectors using a 8-bit Xorshift.
-template <>
-class NEONRoundingOffsetGenerator<RoundingMode::ProbabilisticXorshift> {
- public:
-  NEONRoundingOffsetGenerator() {
-    uint8_t s = 128;
-    std::uint8_t a[16];
-    for (int i = 0; i < 16; i++) {
-      a[i] = s;
-      // Xorshift8(7,7,1). Very important to choose a different
-      // xorshift than we do in get(), otherwise lanes would contain
-      // the same values!
-      s ^= s << 7;
-      s ^= s >> 7;
-      s ^= s << 1;
-    }
-    x_ = vld1q_u8(a);
-  }
-
-  uint8x16_t get() {
-    // Xorshift produces values in [1..255], we want [0..254].
-    uint8x16_t result = vsubq_u8(x_, vdupq_n_u8(1));
-    // Xorshift8(7,5,3)
-    x_ = veorq_u8(x_, vshlq_n_u8(x_, 7));
-    x_ = veorq_u8(x_, vshrq_n_u8(x_, 5));
-    x_ = veorq_u8(x_, vshlq_n_u8(x_, 3));
-    return result;
-  }
-
- private:
-  // State
-  uint8x16_t x_;
-};
-
-// Variant of NEONRoundingOffsetGenerator that produces
-// rounding vectors using an 8-bit add/mod low-discrepancy sequence.
-template <>
-class NEONRoundingOffsetGenerator<RoundingMode::ProbabilisticAddmod> {
- public:
-  NEONRoundingOffsetGenerator() {
-    uint8_t s = 128;
-    std::uint8_t a[16];
-    // The initial offset is set by offsetting each lane to one
-    // more iteration of the sequence (s0...s15)  Then, upon iteration,
-    // each lane moves ahead by 16.
-    for (int i = 0; i < 16; i++) {
-      a[i] = s;
-      s += (97 + (s >= 158));
-    }
-    x_ = vld1q_u8(a);
-  }
-
-  uint8x16_t get() {
-    // Get moves the lane ahead by 16 iterations of the sequence
-    // x_ = (x + (16*97)) % 255.  (16*97)%255 = 22.  255-22=233,
-    // so x_ += (22 + (x >= 233)).
-    // There's an excessively opaque bit hack here:
-    // A "true" compare on NEON produces an all-1s result (0xff).
-    // So instead of adding in the comparison result, we subtract it
-    // to get the same effect as adding 1.
-    uint8x16_t extra_one = vcgeq_u8(x_, vdupq_n_u8(233));
-    x_ = vaddq_u8(x_, vdupq_n_u8(22));
-    x_ = vsubq_u8(x_, extra_one);
-    return x_;
-  }
-
- private:
-  // State
-  uint8x16_t x_;
-};
-
-// Requantizes source uint8 values in [0..255] range
-// to the range specified by BitDepth, [0..((2^bits)-1)].
-// Bias must be avoided. Currently this is achieved
-// by probabilistic rounding.
-template <typename QuantizationParams>
-uint8x16_t Requantize(
-    uint8x16_t raw_src_data,
-    NEONRoundingOffsetGenerator<QuantizationParams::kRoundingMode>*
-        rounding_offset_generator) {
-  static const int kBits = QuantizationParams::BitDepth::kBits;
-  static const std::uint8_t kMaxVal = (1 << kBits) - 1;
-
-  if (kBits == 8) {
-    return raw_src_data;
-  }
-
-  uint8x16_t rounding_offset = rounding_offset_generator->get();
-
-  // Compute:
-  //   x = maxval * src + rounding_offset
-  uint16x8_t x[2];
-  const uint8x8_t maxval_dup = vdup_n_u8(kMaxVal);
-  x[0] = vmlal_u8(vmovl_u8(vget_low_u8(rounding_offset)), maxval_dup,
-                  vget_low_u8(raw_src_data));
-  x[1] = vmlal_u8(vmovl_u8(vget_high_u8(rounding_offset)), maxval_dup,
-                  vget_high_u8(raw_src_data));
-
-  // Divide by 255 (truncating).
-  //
-  // Here we use the following formula, valid for all integers y in 0..65534
-  // (which is more than we need since we've already early-returned
-  // if kBits==8).
-  //
-  //     y/255 = (y + 1 + (y >> 8)) >> 8.
-  uint8x8_t result[2];
-  for (int i = 0; i < 2; i++) {
-    result[i] = vshrn_n_u16(
-        vaddq_u16(vaddq_u16(x[i], vdupq_n_u16(1)), vshrq_n_u16(x[i], 8)), 8);
-  }
-
-  return vcombine_u8(result[0], result[1]);
-}
-
 typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
     WidthMajorUint8SideMap;
 
 template <int Cells>
 using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;
 
-template <typename QuantizationParams, int Cells>
+template <int Cells>
 class PackingRegisterBlock<
-    QuantizationParams, WidthMajorUint8SideMap,
-    PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells> > >
+    WidthMajorUint8SideMap,
+    PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>>
     : public PackingRegisterBlockBase<
-          QuantizationParams, WidthMajorUint8SideMap,
-          PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells> > > {
+          WidthMajorUint8SideMap,
+          PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> {
  public:
   typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
   typedef typename KernelSideFormat::Cell CellFormat;
@@ -177,19 +45,14 @@
   static const int kCellDepth = CellFormat::kDepth;
   static const int kCellSize = CellFormat::kSize;
 
-  typedef NEONRoundingOffsetGenerator<QuantizationParams::kRoundingMode>
-      RoundingOffsetGenerator;
-
-  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width,
-            RoundingOffsetGenerator* rounding_offset_generator) {
+  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
     std::uint8_t* dst_ptr = dst->current_data();
     const std::uint8_t* const src_ptr = this->complete_src_.data();
     const int stride = this->complete_src_.stride();
-    // Load and requantize source WidthMajor data
+    // Load source WidthMajor data
     uint8x16_t src_lines[4 * kCells];
     for (int i = 0; i < 4 * kCells; i++) {
-      src_lines[i] = Requantize<QuantizationParams>(
-          vld1q_u8(src_ptr + i * stride), rounding_offset_generator);
+      src_lines[i] = vld1q_u8(src_ptr + i * stride);
     }
     // Reorder the data within registers to make DepthMajor 4x2 cells
     uint8x16x2_t src_lines_intertwined_2x[2 * kCells];
@@ -267,13 +130,13 @@
 using WidthMajorSideFormatNCells4x2 =
     KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
 
-template <typename QuantizationParams, int Cells>
+template <int Cells>
 class PackingRegisterBlock<
-    QuantizationParams, WidthMajorUint8SideMap,
-    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > >
+    WidthMajorUint8SideMap,
+    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
     : public PackingRegisterBlockBase<
-          QuantizationParams, WidthMajorUint8SideMap,
-          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > {
+          WidthMajorUint8SideMap,
+          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
  public:
   typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
   typedef typename KernelSideFormat::Cell CellFormat;
@@ -283,15 +146,11 @@
   static const int kCellDepth = CellFormat::kDepth;
   static const int kCellSize = CellFormat::kSize;
 
-  typedef NEONRoundingOffsetGenerator<QuantizationParams::kRoundingMode>
-      RoundingOffsetGenerator;
-
-  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width,
-            RoundingOffsetGenerator* rounding_offset_generator) {
+  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
     std::uint8_t* dst_ptr = dst->current_data();
     const std::uint8_t* src_ptr = this->complete_src_.data();
     const int stride = this->complete_src_.stride();
-    // Load and requantize source WidthMajor data
+    // Load source WidthMajor data
     uint16x8_t src_lines[kCells * 4];
     for (int i = 0; i < kCells; i++) {
 // This packing path is used with our current
@@ -299,9 +158,8 @@
 // results in substantially faster code (thanks to better
 // register allocation) on Nexus 5.
 
-#define GEMMLOWP_UNROLLED_LOOP_ITER(k)                                        \
-  src_lines[4 * i + k] = vreinterpretq_u16_u8(Requantize<QuantizationParams>( \
-      vld1q_u8(src_ptr), rounding_offset_generator));                         \
+#define GEMMLOWP_UNROLLED_LOOP_ITER(k)                            \
+  src_lines[4 * i + k] = vreinterpretq_u16_u8(vld1q_u8(src_ptr)); \
   src_ptr += stride;
 
       GEMMLOWP_UNROLLED_LOOP_ITER(0)
@@ -385,6 +243,78 @@
   }
 };
 
+#ifdef GEMMLOWP_NEON_32
+inline int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+  const int16x4_t c = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+  const int16x4_t d = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+  return vcombine_s16(c, d);
+}
+#endif
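+
+This 32-bit NEON fallback rebuilds vpaddq_s16 from two 64-bit pairwise adds; the intended result, written as a scalar sketch (illustrative helper):
+
+#include <cstdint>
+
+// out[0..3] are the adjacent-pair sums of a, out[4..7] those of b.
+void PairwiseAddS16Reference(const std::int16_t a[8], const std::int16_t b[8],
+                             std::int16_t out[8]) {
+  for (int i = 0; i < 4; i++) {
+    out[i] = a[2 * i] + a[2 * i + 1];
+    out[4 + i] = b[2 * i] + b[2 * i + 1];
+  }
+}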
+
+template <int Width>
+using Int8FastKernelFormat =
+    KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
+
+template <int Width>
+class PackingRegisterBlock<WidthMajorUint8SideMap,
+                           PackedSideBlock<Int8FastKernelFormat<Width>>>
+    : public PackingRegisterBlockBase<
+          WidthMajorUint8SideMap,
+          PackedSideBlock<Int8FastKernelFormat<Width>>> {
+ public:
+  static_assert(Width == 2 || Width == 4, "");
+  typedef Int8FastKernelFormat<Width> KernelSideFormat;
+  typedef typename KernelSideFormat::Cell CellFormat;
+  static const int kCells = KernelSideFormat::kCells;
+  static const int kCellWidth = CellFormat::kWidth;
+  static const int kKernelWidth = CellFormat::kWidth * kCells;
+  static const int kCellDepth = CellFormat::kDepth;
+  static const int kCellSize = CellFormat::kSize;
+
+  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+    std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
+    std::uint8_t* dst_ptr = dst->current_data();
+    const std::uint8_t* const src_ptr = this->complete_src_.data();
+    const int stride = this->complete_src_.stride();
+    // Load source WidthMajor data
+    uint8x16_t src_lines[Width];
+    for (int i = 0; i < Width; i++) {
+      src_lines[i] = vld1q_u8(src_ptr + i * stride);
+    }
+    const uint8x16_t sign_bit_dup = vdupq_n_u8(0x80);
+    for (int i = 0; i < Width; i++) {
+      src_lines[i] = veorq_u8(src_lines[i], sign_bit_dup);
+    }
+    for (int i = 0; i < Width; i++) {
+      vst1q_u8(dst_ptr + 16 * i, src_lines[i]);
+    }
+    int16x8_t sums2[Width];
+    for (int i = 0; i < Width; i++) {
+      const int8x8_t lo = vreinterpret_s8_u8(vget_low_u8(src_lines[i]));
+      const int8x8_t hi = vreinterpret_s8_u8(vget_high_u8(src_lines[i]));
+      sums2[i] = vaddl_s8(lo, hi);
+    }
+    int16x8_t sums4[Width / 2];
+    for (int i = 0; i < Width / 2; i++) {
+      sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]);
+    }
+    if (Width == 4) {
+      int32x4_t sum = vld1q_s32(sums_ptr);
+      int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]);
+      sum = vpadalq_s16(sum, sums8);
+      vst1q_s32(sums_ptr, sum);
+    } else {
+      assert(Width == 2);
+      int32x2_t sum = vld1_s32(sums_ptr);
+      int16x4_t sums8 =
+          vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0]));
+      sum = vpadal_s16(sum, sums8);
+      vst1_s32(sums_ptr, sum);
+    }
+    dst->seek_forward_n_cells(1);
+  }
+};
+
 }  // namespace gemmlowp
 
 #endif  // GEMMLOWP_INTERNAL_PACK_NEON_H_
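
The veorq_u8(src, 0x80) step in the new int8 packing path relies on the identity that flipping the top bit of a uint8 and reinterpreting the result as a two's-complement int8 equals subtracting 128, which is consistent with the zero-point handling in pack.h. A small self-check sketch of that assumption:

#include <cassert>
#include <cstdint>

void CheckSignFlipTrick() {
  for (int v = 0; v < 256; v++) {
    const std::uint8_t flipped = static_cast<std::uint8_t>(v) ^ 0x80;
    assert(static_cast<std::int8_t>(flipped) == v - 128);
  }
}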
diff --git a/internal/pack_sse.h b/internal/pack_sse.h
new file mode 100644
index 0000000..52163c4
--- /dev/null
+++ b/internal/pack_sse.h
@@ -0,0 +1,128 @@
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// pack_sse.h: optimized SSE specializations of the templates in pack.h.
+
+#ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_
+#define GEMMLOWP_INTERNAL_PACK_SSE_H_
+
+#include <smmintrin.h>
+#include "pack.h"
+
+namespace gemmlowp {
+
+// TODO: Add DepthMajorUint8SideMap
+
+typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
+    WidthMajorUint8SideMap;
+
+template <int Cells>
+using WidthMajorSideFormatNCells4x2 =
+    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
+
+template <int Cells>
+class PackingRegisterBlock<
+    WidthMajorUint8SideMap,
+    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > >
+    : public PackingRegisterBlockBase<
+          WidthMajorUint8SideMap,
+          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > {
+ public:
+  typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
+  typedef typename KernelSideFormat::Cell CellFormat;
+  static const int kCells = KernelSideFormat::kCells;
+  static const int kCellWidth = CellFormat::kWidth;
+  static const int kKernelWidth = CellFormat::kWidth * kCells;
+  static const int kCellDepth = CellFormat::kDepth;
+  static const int kCellSize = CellFormat::kSize;
+
+  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+    std::uint8_t* dst_ptr = dst->current_data();
+    const int width_stride = this->complete_src_.width_stride();
+    int depth_step = 8;
+
+    __m128i one = _mm_set1_epi16(1);
+    for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
+         cell_start_depth += depth_step) {
+      for (int cell_start_width = 0; cell_start_width < kKernelWidth;
+           cell_start_width += kCellWidth) {
+        std::int32_t* cell_sums_of_each_slice_ptr =
+            dst->sums_of_each_slice() + start_width + cell_start_width;
+        const std::uint8_t* src_data =
+            this->complete_src_.data(cell_start_width, cell_start_depth);
+
+        __m128i xmm1 =
+            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0]));
+        __m128i xmm2 = _mm_loadl_epi64(
+            reinterpret_cast<const __m128i*>(&src_data[1 * width_stride]));
+        __m128i xmm3 = _mm_loadl_epi64(
+            reinterpret_cast<const __m128i*>(&src_data[2 * width_stride]));
+        __m128i xmm4 = _mm_loadl_epi64(
+            reinterpret_cast<const __m128i*>(&src_data[3 * width_stride]));
+
+        __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
+        __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
+
+        __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
+        __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
+
+        __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
+        __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
+
+        _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9);
+        _mm_storel_epi64(
+            reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10);
+
+        __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
+        __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
+
+        _mm_storel_epi64(
+            reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]),
+            xmm11);
+        _mm_storel_epi64(
+            reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]),
+            xmm12);
+
+        xmm1 = _mm_cvtepu8_epi16(xmm9);
+        xmm2 = _mm_madd_epi16(xmm1, one);
+        __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
+            reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0]));
+        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+        xmm1 = _mm_cvtepu8_epi16(xmm10);
+        xmm2 = _mm_madd_epi16(xmm1, one);
+        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+        xmm1 = _mm_cvtepu8_epi16(xmm11);
+        xmm2 = _mm_madd_epi16(xmm1, one);
+        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+        xmm1 = _mm_cvtepu8_epi16(xmm12);
+        xmm2 = _mm_madd_epi16(xmm1, one);
+        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+        _mm_storeu_si128(
+            reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]),
+            sums_of_each_slice_xmm);
+        dst_ptr += kCellSize;
+      }
+      dst_ptr += 3 * kCellSize * kCells;
+    }
+    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
+  }
+};
+
+}  // namespace gemmlowp
+
+#endif  // GEMMLOWP_INTERNAL_PACK_SSE_H_
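+
+The _mm_madd_epi16(x, one) steps above compute the per-slice sums on the fly: each packed byte is widened to 16 bits and multiply-added against 1, which sums adjacent pairs. Assuming the WidthMajor 4x2 cell layout stores the two depth values of each width contiguously, that is exactly one depth pair per slice, as in this scalar sketch (illustrative helper):
+
+#include <cstdint>
+
+void AccumulateCellSumsReference(const std::uint8_t packed_cell[8],
+                                 std::int32_t sums_of_each_slice[4]) {
+  for (int w = 0; w < 4; w++) {
+    sums_of_each_slice[w] += packed_cell[2 * w] + packed_cell[2 * w + 1];
+  }
+}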
diff --git a/internal/simd_wrappers.h b/internal/simd_wrappers.h
new file mode 100644
index 0000000..e39eaf8
--- /dev/null
+++ b/internal/simd_wrappers.h
@@ -0,0 +1,508 @@
+// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
+// extending the set of such functions from fixedpoint.h.
+
+#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
+#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
+
+#include <algorithm>
+#include <type_traits>
+#include "../fixedpoint/fixedpoint.h"
+
+namespace gemmlowp {
+
+template <typename ScalarType, int ScalarCount>
+struct RegisterType {
+  using Type = ScalarType;
+};
+
+inline std::int32_t Min(std::int32_t a, std::int32_t b) {
+  return std::min(a, b);
+}
+
+inline std::int32_t Max(std::int32_t a, std::int32_t b) {
+  return std::max(a, b);
+}
+
+inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
+  *acc += lhs * rhs;
+}
+
+template <typename tScalarType, int tScalarCount>
+struct RegisterBuffer {
+  using ScalarType = tScalarType;
+  static constexpr int kScalarCount = tScalarCount;
+  using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
+  static_assert((kScalarCount & (kScalarCount - 1)) == 0,
+                "kScalarCount must be a power of two");
+  static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
+  static constexpr int kRegisterLanes =
+      sizeof(RegisterType) / sizeof(ScalarType);
+  static constexpr int kRegisterCount =
+      (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) /
+      sizeof(RegisterType);
+
+  RegisterType reg[kRegisterCount];
+};
+
+template <typename tScalarType, int tRows, int tCols>
+struct RegisterBlock {
+  using ScalarType = tScalarType;
+  static constexpr int kRows = tRows;
+  static constexpr int kCols = tCols;
+  static constexpr int kScalarCount = kRows * kCols;
+  using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
+  using RegisterType = typename BufferType::RegisterType;
+  static constexpr int kRegisterCount = BufferType::kRegisterCount;
+  static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
+
+  BufferType buf;
+};
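+
+As a worked example of the counts above: a RegisterBlock<std::int32_t, 8, 8> always has kScalarCount = 64; with the generic scalar RegisterType it occupies 64 single-lane registers, while a 128-bit SIMD specialization (four int32 lanes, as on NEON or SSE) would give kRegisterLanes = 4 and kRegisterCount = 16. Only the lane-independent part is asserted in this sketch, which assumes this header has been included:
+
+#include <cstdint>
+
+using Block8x8 = gemmlowp::RegisterBlock<std::int32_t, 8, 8>;
+static_assert(Block8x8::kScalarCount == 64, "an 8x8 block holds 64 scalars");
+// kRegisterLanes and kRegisterCount depend on the RegisterType
+// specializations provided by the SIMD-specific headers included below.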
+
+template <typename RegisterBlockType>
+struct RegisterBlockAddImpl {
+  static RegisterBlockType Run(const RegisterBlockType& lhs,
+                               const RegisterBlockType& rhs) {
+    RegisterBlockType result;
+    for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
+      result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]);
+    }
+    return result;
+  }
+};
+
+template <typename RegisterBlockType>
+RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
+                                   const RegisterBlockType& rhs) {
+  return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs);
+}
+
+template <typename LhsType, typename RhsType>
+struct ShouldFlipLhsRhs {
+  static constexpr bool kValue =
+      (LhsType::kScalarCount < RhsType::kScalarCount) ||
+      (LhsType::kScalarCount == RhsType::kScalarCount &&
+       (LhsType::kRows < RhsType::kRows));
+};
+
+template <typename LhsType, typename RhsType,
+          bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
+struct FlipLhsRhs {
+  using FlippedLhsType = LhsType;
+  using FlippedRhsType = RhsType;
+  static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
+                                          const RhsType& rhs) {
+    return lhs;
+  }
+  static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
+                                          const RhsType& rhs) {
+    return rhs;
+  }
+};
+
+template <typename LhsType, typename RhsType>
+struct FlipLhsRhs<LhsType, RhsType, true> {
+  using FlippedLhsType = RhsType;
+  using FlippedRhsType = LhsType;
+  static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
+                                          const RhsType& rhs) {
+    return rhs;
+  }
+  static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
+                                          const RhsType& rhs) {
+    return lhs;
+  }
+};
+
+template <typename Lhs, typename Rhs>
+struct BroadcastBinaryOpShape {
+  static constexpr int kRows =
+      Lhs::kRows > Rhs::kRows ? Lhs::kRows : Rhs::kRows;
+  static constexpr int kCols =
+      Lhs::kCols > Rhs::kCols ? Lhs::kCols : Rhs::kCols;
+};
+
+template <typename Lhs, typename Rhs>
+struct BroadcastBinaryOpRegisterBlock {
+  using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
+  using ScalarType = typename Lhs::ScalarType;
+  using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
+};
+
+template <typename Lhs, typename Rhs>
+struct BroadcastAddImpl {
+  using ResultBlockType =
+      typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+  static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+    ResultBlockType result;
+    static constexpr int Rows = ResultBlockType::kRows;
+    static constexpr int Cols = ResultBlockType::kCols;
+    static constexpr int LhsRows = Lhs::kRows;
+    static constexpr int LhsCols = Lhs::kCols;
+    static constexpr int RhsRows = Rhs::kRows;
+    static constexpr int RhsCols = Rhs::kCols;
+
+    static_assert(LhsRows == Rows || LhsRows == 1, "");
+    static_assert(RhsRows == Rows || RhsRows == 1, "");
+    static_assert(LhsCols == Cols || LhsCols == 1, "");
+    static_assert(RhsCols == Cols || RhsCols == 1, "");
+    static_assert(ResultBlockType::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Lhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Rhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+
+    for (int c = 0; c < Cols; c++) {
+      const int lhs_c = LhsCols == Cols ? c : 0;
+      const int rhs_c = RhsCols == Cols ? c : 0;
+      for (int r = 0; r < Rows; r++) {
+        const int lhs_r = LhsRows == Rows ? r : 0;
+        const int rhs_r = RhsRows == Rows ? r : 0;
+        result.buf.reg[r + c * Rows] =
+            Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
+                rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
+      }
+    }
+    return result;
+  }
+};
+
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
+    const Lhs& lhs, const Rhs& rhs) {
+  using Flip = FlipLhsRhs<Lhs, Rhs>;
+  return BroadcastAddImpl<
+      typename Flip::FlippedLhsType,
+      typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
+                                          Flip::FlippedRhs(lhs, rhs));
+}
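+
+BroadcastAdd broadcasts over the two block shapes: along any dimension where an operand has size 1, its index is clamped to 0. For example, combining a 4x1 column block with a 1x4 row block yields a 4x4 block with result(r, c) = col[r] + row[c]; the scalar reference below sketches that case without invoking the template (illustrative helper):
+
+#include <cstdint>
+
+void BroadcastAdd4x1With1x4Reference(const std::int32_t col[4],
+                                     const std::int32_t row[4],
+                                     std::int32_t result[4][4]) {
+  for (int c = 0; c < 4; c++) {
+    for (int r = 0; r < 4; r++) {
+      result[r][c] = col[r] + row[c];
+    }
+  }
+}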
+
+template <typename Lhs, typename Rhs>
+struct BroadcastMulImpl {
+  using ResultBlockType =
+      typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+  static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+    ResultBlockType result;
+    static constexpr int Rows = ResultBlockType::kRows;
+    static constexpr int Cols = ResultBlockType::kCols;
+    static constexpr int LhsRows = Lhs::kRows;
+    static constexpr int LhsCols = Lhs::kCols;
+    static constexpr int RhsRows = Rhs::kRows;
+    static constexpr int RhsCols = Rhs::kCols;
+    static_assert(ResultBlockType::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Lhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Rhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+
+    static_assert(LhsRows == Rows || LhsRows == 1, "");
+    static_assert(RhsRows == Rows || RhsRows == 1, "");
+    static_assert(LhsCols == Cols || LhsCols == 1, "");
+    static_assert(RhsCols == Cols || RhsCols == 1, "");
+    for (int c = 0; c < Cols; c++) {
+      const int lhs_c = LhsCols == Cols ? c : 0;
+      const int rhs_c = RhsCols == Cols ? c : 0;
+      for (int r = 0; r < Rows; r++) {
+        const int lhs_r = LhsRows == Rows ? r : 0;
+        const int rhs_r = RhsRows == Rows ? r : 0;
+        result.buf.reg[r + c * Rows] =
+            Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
+                rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
+      }
+    }
+    return result;
+  }
+};
+
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
+    const Lhs& lhs, const Rhs& rhs) {
+  using Flip = FlipLhsRhs<Lhs, Rhs>;
+  return BroadcastMulImpl<
+      typename Flip::FlippedLhsType,
+      typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
+                                          Flip::FlippedRhs(lhs, rhs));
+}
+
+template <typename Lhs, typename Rhs, typename Acc>
+struct BroadcastMulAddImpl {
+  static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
+    static constexpr int Rows = Acc::kRows;
+    static constexpr int Cols = Acc::kCols;
+    static constexpr int LhsRows = Lhs::kRows;
+    static constexpr int LhsCols = Lhs::kCols;
+    static constexpr int RhsRows = Rhs::kRows;
+    static constexpr int RhsCols = Rhs::kCols;
+    static_assert(Acc::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Lhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Rhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+
+    static_assert(LhsRows == Rows || LhsRows == 1, "");
+    static_assert(RhsRows == Rows || RhsRows == 1, "");
+    static_assert(LhsCols == Cols || LhsCols == 1, "");
+    static_assert(RhsCols == Cols || RhsCols == 1, "");
+    for (int c = 0; c < Cols; c++) {
+      const int lhs_c = LhsCols == Cols ? c : 0;
+      const int rhs_c = RhsCols == Cols ? c : 0;
+      for (int r = 0; r < Rows; r++) {
+        const int lhs_r = LhsRows == Rows ? r : 0;
+        const int rhs_r = RhsRows == Rows ? r : 0;
+        MulAdd(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
+               rhs.buf.reg[rhs_r + rhs_c * RhsRows],
+               &acc->buf.reg[r + c * Rows]);
+      }
+    }
+  }
+};
+
+template <typename Lhs, typename Rhs, typename Acc>
+void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
+  using Flip = FlipLhsRhs<Lhs, Rhs>;
+  BroadcastMulAddImpl<typename Flip::FlippedLhsType,
+                      typename Flip::FlippedRhsType,
+                      Acc>::Run(Flip::FlippedLhs(lhs, rhs),
+                                Flip::FlippedRhs(lhs, rhs), acc);
+}
+
+template <typename RegisterBlockType, typename SrcObjectType>
+struct LoadImpl {
+  static_assert(std::is_same<SrcObjectType, void>::value,
+                "This generic impl should never be hit");
+};
+
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
+struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
+  static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
+    RegisterBlockType result;
+    int i = 0;
+    for (int c = 0; c < Cols; c++) {
+      const ScalarType* src_ptr = src.data(row, col + c);
+      for (int r = 0; r < Rows; r++) {
+        result.buf.reg[i++] = *src_ptr++;
+      }
+    }
+    return result;
+  }
+};
+
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
+          VectorShape Shape>
+struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                VectorMap<SrcScalarType, Shape>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = VectorMap<SrcScalarType, Shape>;
+  static RegisterBlockType Run(const SrcObjectType& src, int pos) {
+    static_assert(Shape == VectorShape::Col || Rows == 1, "");
+    static_assert(Shape == VectorShape::Row || Cols == 1, "");
+    RegisterBlockType result;
+    for (int i = 0; i < Rows * Cols; i++) {
+      result.buf.reg[i] = src(pos + i);
+    }
+    return result;
+  }
+};
+
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
+          VectorShape Shape>
+struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                VectorDup<SrcScalarType, Shape>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = VectorDup<SrcScalarType, Shape>;
+  static RegisterBlockType Run(const SrcObjectType& src, int) {
+    static_assert(Shape == VectorShape::Col || Rows == 1, "");
+    static_assert(Shape == VectorShape::Row || Cols == 1, "");
+    RegisterBlockType result;
+    for (int i = 0; i < Rows * Cols; i++) {
+      result.buf.reg[i] = src(0);
+    }
+    return result;
+  }
+};
+
+template <typename RegisterBlockType, typename SrcObjectType>
+RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
+  return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col);
+}
+
+template <typename RegisterBlockType, typename SrcObjectType>
+RegisterBlockType Load(const SrcObjectType& src, int pos) {
+  return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos);
+}
+
+template <typename RegisterBlockType>
+struct LoadContiguousImpl {
+  using ScalarType = typename RegisterBlockType::ScalarType;
+  static_assert(RegisterBlockType::kRegisterLanes == 1,
+                "This path is only for scalar values");
+  static RegisterBlockType Run(const ScalarType* src) {
+    RegisterBlockType result;
+    for (int i = 0; i < RegisterBlockType::kScalarCount; i++) {
+      result.buf.reg[i] = src[i];
+    }
+    return result;
+  }
+};
+
+template <typename RegisterBlockType>
+RegisterBlockType LoadContiguous(
+    const typename RegisterBlockType::ScalarType* src) {
+  return LoadContiguousImpl<RegisterBlockType>::Run(src);
+}
+
+template <int BroadcastRows, int BroadcastCols, typename SrcObjectType>
+struct LoadForBroadcastingShape {};
+
+template <int BroadcastRows, int BroadcastCols, typename ScalarType,
+          VectorShape Shape>
+struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
+                                VectorMap<ScalarType, Shape>> {
+  static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1;
+  static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1;
+};
+
+template <int BroadcastRows, int BroadcastCols, typename ScalarType,
+          VectorShape Shape>
+struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
+                                VectorDup<ScalarType, Shape>> {
+  static constexpr int kRows = 1;
+  static constexpr int kCols = 1;
+};
+
+template <typename RegisterBlockType, typename SrcObjectType>
+struct LoadForBroadcastingRegisterBlock {
+  using Shape =
+      LoadForBroadcastingShape<RegisterBlockType::kRows,
+                               RegisterBlockType::kCols, SrcObjectType>;
+  using ScalarType = typename RegisterBlockType::ScalarType;
+  using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
+};
+
+template <typename RegisterBlockType, typename SrcObjectType>
+struct LoadForBroadcastingImpl {
+  static_assert(std::is_same<SrcObjectType, void>::value,
+                "This generic impl should never be hit");
+};
+
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
+          VectorShape Shape>
+struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                               VectorMap<SrcScalarType, Shape>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = VectorMap<SrcScalarType, Shape>;
+  using ResultBlockType =
+      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                                SrcObjectType>::Type;
+  static_assert(ResultBlockType::kRegisterLanes == 1,
+                "This path is only for scalar values");
+  static ResultBlockType Run(const SrcObjectType& src, int pos) {
+    ResultBlockType result;
+    for (int c = 0; c < ResultBlockType::kCols; c++) {
+      for (int r = 0; r < ResultBlockType::kRows; r++) {
+        const int i = Shape == VectorShape::Col ? r : c;
+        result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i);
+      }
+    }
+    return result;
+  }
+};
+
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
+          VectorShape Shape>
+struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                               VectorDup<SrcScalarType, Shape>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = VectorDup<SrcScalarType, Shape>;
+  using ResultBlockType =
+      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                                SrcObjectType>::Type;
+  static_assert(ResultBlockType::kRegisterLanes == 1,
+                "This path is only for scalar values");
+  static ResultBlockType Run(const SrcObjectType& src, int) {
+    ResultBlockType result;
+    for (int c = 0; c < ResultBlockType::kCols; c++) {
+      for (int r = 0; r < ResultBlockType::kRows; r++) {
+        result.buf.reg[r + c * ResultBlockType::kRows] = src(0);
+      }
+    }
+    return result;
+  }
+};
+
+template <typename RegisterBlockType, typename SrcObjectType>
+typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                          SrcObjectType>::Type
+LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
+  return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(
+      src, row, col);
+}
+
+template <typename RegisterBlockType, typename SrcObjectType>
+typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                          SrcObjectType>::Type
+LoadForBroadcasting(const SrcObjectType& src, int pos) {
+  return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src,
+                                                                        pos);
+}
+
+template <int ConstantValue, typename RegisterBlockType>
+struct AddConstantImpl {
+  static void Run(RegisterBlockType* block) {
+    using RegisterType = typename RegisterBlockType::RegisterType;
+    const RegisterType dup = Dup<RegisterType>(ConstantValue);
+    for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
+      block->buf.reg[i] = Add(block->buf.reg[i], dup);
+    }
+  }
+};
+
+template <typename RegisterBlockType>
+struct AddConstantImpl<0, RegisterBlockType> {
+  static void Run(RegisterBlockType*) {
+    // This is a no-op.
+  }
+};
+
+template <int ConstantValue, typename RegisterBlockType>
+void AddConstant(RegisterBlockType* block) {
+  AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block);
+}
+
+template <int N>
+using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
+template <int N>
+using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
+template <int R, int C>
+using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
+template <int R, int C>
+using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
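+
+// For example, RegBlockInt32<8, 4> is an 8x4 column-major block of int32
+// values; on the NEON/SSE paths included below, its buffer is eight
+// four-lane registers (buf.reg[0]..buf.reg[7]).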
+
+}  // end namespace gemmlowp
+
+#if defined GEMMLOWP_NEON
+#include "simd_wrappers_neon.h"
+#elif defined GEMMLOWP_SSE4
+#include "simd_wrappers_sse.h"
+#endif
+
+#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
diff --git a/internal/simd_wrappers_common_neon_sse.h b/internal/simd_wrappers_common_neon_sse.h
new file mode 100644
index 0000000..3830eb1
--- /dev/null
+++ b/internal/simd_wrappers_common_neon_sse.h
@@ -0,0 +1,646 @@
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code
+
+#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
+#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
+
+#include "simd_wrappers.h"
+
+namespace gemmlowp {
+
+template <typename SrcScalarType, int N>
+struct LoadImpl<RegBlockInt32<4, N>,
+                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
+  static RegBlockInt32<4, N> Run(
+      const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
+      int col) {
+    RegBlockInt32<4, N> result;
+    for (int i = 0; i < N; i++) {
+      result.buf.reg[i] = LoadInt32x4(src.data(row, col + i));
+    }
+    return result;
+  }
+};
+
+template <typename SrcScalarType, int N>
+struct LoadImpl<RegBlockInt32<8, N>,
+                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
+  static RegBlockInt32<8, N> Run(
+      const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
+      int col) {
+    RegBlockInt32<8, N> result;
+    for (int i = 0; i < N; i++) {
+      result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i));
+      result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i));
+    }
+    return result;
+  }
+};
+
+template <typename SrcScalarType>
+struct LoadImpl<RegBlockInt32<1, 4>,
+                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
+  static RegBlockInt32<1, 4> Run(
+      const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
+      int col) {
+    RegBlockInt32<1, 4> result;
+    std::int32_t buf[4];
+    for (int i = 0; i < 4; i++) {
+      buf[i] = src(row, col + i);
+    }
+    result.buf.reg[0] = LoadInt32x4(buf);
+    return result;
+  }
+};
+
+template <typename SrcScalarType>
+struct LoadImpl<RegBlockInt32<1, 8>,
+                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
+  static RegBlockInt32<1, 8> Run(
+      const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
+      int col) {
+    RegBlockInt32<1, 8> result;
+    std::int32_t buf[8];
+    for (int i = 0; i < 8; i++) {
+      buf[i] = src(row, col + i);
+    }
+    result.buf.reg[0] = LoadInt32x4(buf);
+    result.buf.reg[1] = LoadInt32x4(buf + 4);
+    return result;
+  }
+};
+
+template <typename SrcScalarType>
+struct LoadImpl<RegBlockInt32<4, 1>,
+                VectorMap<SrcScalarType, VectorShape::Col>> {
+  static RegBlockInt32<4, 1> Run(
+      const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = LoadInt32x4(src.data(pos));
+    return result;
+  }
+};
+
+template <typename SrcScalarType>
+struct LoadImpl<RegBlockInt32<4, 1>,
+                VectorDup<SrcScalarType, VectorShape::Col>> {
+  static RegBlockInt32<4, 1> Run(
+      const VectorDup<SrcScalarType, VectorShape::Col>& src, int) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = LoadInt32x4(src(0));
+    return result;
+  }
+};
+
+template <typename SrcScalarType, int N>
+struct LoadForBroadcastingImpl<RegBlockInt32<4, N>,
+                               VectorMap<SrcScalarType, VectorShape::Col>> {
+  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
+  using RegisterBlockType = RegBlockInt32<4, N>;
+  using ResultBlockType =
+      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                                SrcObjectType>::Type;
+
+  static ResultBlockType Run(const SrcObjectType& src, int pos) {
+    ResultBlockType result;
+    static_assert(ResultBlockType::kRegisterCount == 1, "");
+    result.buf.reg[0] = LoadInt32x4(src.data(pos));
+    return result;
+  }
+};
+
+template <typename SrcScalarType, int N>
+struct LoadForBroadcastingImpl<RegBlockInt32<8, N>,
+                               VectorMap<SrcScalarType, VectorShape::Col>> {
+  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
+  using RegisterBlockType = RegBlockInt32<8, N>;
+  using ResultBlockType =
+      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                                SrcObjectType>::Type;
+
+  static ResultBlockType Run(const SrcObjectType& src, int pos) {
+    ResultBlockType result;
+    static_assert(ResultBlockType::kRegisterCount == 2, "");
+    result.buf.reg[0] = LoadInt32x4(src.data(pos));
+    result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
+    return result;
+  }
+};
+
+template <typename SrcScalarType>
+struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>,
+                               VectorMap<SrcScalarType, VectorShape::Row>> {
+  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
+  using RegisterBlockType = RegBlockInt32<4, 1>;
+  using ResultBlockType =
+      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                                SrcObjectType>::Type;
+
+  static ResultBlockType Run(const SrcObjectType& src, int pos) {
+    ResultBlockType result;
+    result.buf.reg[0] = src(pos);
+    return result;
+  }
+};
+
+template <typename SrcScalarType, int N>
+struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>,
+                               VectorMap<SrcScalarType, VectorShape::Row>> {
+  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
+  using RegisterBlockType = RegBlockInt32<N, 4>;
+  using ResultBlockType =
+      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                                SrcObjectType>::Type;
+
+  static ResultBlockType Run(const SrcObjectType& src, int pos) {
+    ResultBlockType result;
+    static_assert(ResultBlockType::kRegisterCount == 1, "");
+    result.buf.reg[0] = LoadInt32x4(src.data(pos));
+    return result;
+  }
+};
+
+template <typename SrcScalarType, int N>
+struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>,
+                               VectorMap<SrcScalarType, VectorShape::Row>> {
+  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
+  using RegisterBlockType = RegBlockInt32<N, 8>;
+  using ResultBlockType =
+      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                                SrcObjectType>::Type;
+
+  static ResultBlockType Run(const SrcObjectType& src, int pos) {
+    ResultBlockType result;
+    static_assert(ResultBlockType::kRegisterCount == 2, "");
+    result.buf.reg[0] = LoadInt32x4(src.data(pos));
+    result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
+    return result;
+  }
+};
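+
+// The specializations below implement broadcasted elementwise operations.
+// Their "RxC := RxC op rxc" comments are shorthand: for instance,
+// "4x4 := 4x4 + 1x4" adds lane c of the 1x4 row vector to every entry of
+// column c of the 4x4 block (hence the DupLane<c> calls there).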
+
+// 4x1 := 4x1 + 1x1
+template <>
+struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x1
+template <>
+struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x1 := 4x1 + 4x1
+template <>
+struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 1x4 := 1x4 + 1x4
+template <>
+struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 1x4
+template <>
+struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x4 := 4x4 + 4x1
+template <>
+struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 4> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]);
+    result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 1x1
+template <>
+struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = Add(lhs.buf.reg[i], p);
+    }
+    return result;
+  }
+};
+
+// 8x1 := 8x1 + 8x1
+template <>
+struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]);
+    }
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 1x4
+template <>
+struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
+    result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
+    result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
+    result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
+    result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 8x4 := 8x4 + 8x1
+template <>
+struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 4> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
+    result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
+    result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]);
+    result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]);
+    result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]);
+    result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]);
+    result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x8
+template <>
+struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 8>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
+    result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
+    return result;
+  }
+};
+
+// 1x8 := 1x8 + 1x1
+template <>
+struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 8> result;
+    result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x1 := 4x1 * 1x1
+template <>
+struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+    return result;
+  }
+};
+
+// 4x1 := 4x1 * 4x1
+template <>
+struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 1> result;
+    result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 1x4 := 1x4 * 1x4
+template <>
+struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 1x4 := 1x4 * 1x1
+template <>
+struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<1, 4> result;
+    result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    return result;
+  }
+};
+
+// 4x4 := 4x4 * 1x4
+template <>
+struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<4, 4> result;
+    const Int32x4 p = rhs.buf.reg[0];
+    result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p);
+    result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p);
+    result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p);
+    result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p);
+    return result;
+  }
+};
+
+// 4x4 := 4x4 * 4x1
+template <>
+struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
+  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+                                 const RegBlockInt32<4, 1>& rhs) {
+    RegBlockInt32<4, 4> result;
+    const Int32x4 p = rhs.buf.reg[0];
+    result.buf.reg[0] = Mul(lhs.buf.reg[0], p);
+    result.buf.reg[1] = Mul(lhs.buf.reg[1], p);
+    result.buf.reg[2] = Mul(lhs.buf.reg[2], p);
+    result.buf.reg[3] = Mul(lhs.buf.reg[3], p);
+    return result;
+  }
+};
+
+// 8x1 := 8x1 * 1x1
+template <>
+struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<1, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    const std::int32_t p = rhs.buf.reg[0];
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = Mul(lhs.buf.reg[i], p);
+    }
+    return result;
+  }
+};
+
+// 8x1 := 8x1 * 8x1
+template <>
+struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 1> result;
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]);
+    }
+    return result;
+  }
+};
+
+// 8x4 := 8x4 * 1x4
+template <>
+struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<1, 4>& rhs) {
+    RegBlockInt32<8, 4> result;
+    const Int32x4 p = rhs.buf.reg[0];
+    for (int i = 0; i < 2; i++) {
+      result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p);
+      result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p);
+      result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p);
+      result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p);
+    }
+    return result;
+  }
+};
+
+// 8x4 := 8x4 * 8x1
+template <>
+struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
+  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+                                 const RegBlockInt32<8, 1>& rhs) {
+    RegBlockInt32<8, 4> result;
+    const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]};
+    for (int i = 0; i < 4; i++) {
+      for (int j = 0; j < 2; j++) {
+        const int k = j + 2 * i;
+        result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]);
+      }
+    }
+    return result;
+  }
+};
+
+// Rx1 += Rx1 * 1x1
+template <int Rows>
+struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<Rows, 1>> {
+  static void Run(const RegBlockInt32<Rows, 1>& lhs,
+                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) {
+    const std::int32_t p = rhs.buf.reg[0];
+    for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) {
+      MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
+    }
+  }
+};
+
+// RxC += Rx1 * 1x1
+template <int Rows, int Cols>
+struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<Rows, Cols>> {
+  static void Run(const RegBlockInt32<Rows, 1>& lhs,
+                  const RegBlockInt32<1, 1>& rhs,
+                  RegBlockInt32<Rows, Cols>* acc) {
+    const std::int32_t p = rhs.buf.reg[0];
+    static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
+    for (int i = 0; i < kRegsPerCol; i++) {
+      const Int32x4 q = Mul(lhs.buf.reg[i], p);
+      for (int j = 0; j < Cols; j++) {
+        acc->buf.reg[i + j * kRegsPerCol] =
+            Add(acc->buf.reg[i + j * kRegsPerCol], q);
+      }
+    }
+  }
+};
+
+// 1xC += 1xC * 1x1
+template <int Cols>
+struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<1, Cols>> {
+  static void Run(const RegBlockInt32<1, Cols>& lhs,
+                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
+    const std::int32_t p = rhs.buf.reg[0];
+    for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
+      MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
+    }
+  }
+};
+
+// RxC += 1x1 * 1x1
+template <int Rows, int Cols>
+struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<Rows, Cols>> {
+  static void Run(const RegBlockInt32<1, 1>& lhs,
+                  const RegBlockInt32<1, 1>& rhs,
+                  RegBlockInt32<Rows, Cols>* acc) {
+    const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
+    for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) {
+      acc->buf.reg[i] = Add(acc->buf.reg[i], p);
+    }
+  }
+};
+
+// 1x1 += 1x1 * 1x1
+template <>
+struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<1, 1>> {
+  static void Run(const RegBlockInt32<1, 1>& lhs,
+                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) {
+    MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]);
+  }
+};
+
+// Rx4 += Rx1 * 1x4
+template <int Rows>
+struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>,
+                           RegBlockInt32<Rows, 4>> {
+  static void Run(const RegBlockInt32<Rows, 1>& lhs,
+                  const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) {
+    const Int32x4 p = rhs.buf.reg[0];
+    static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
+    for (int i = 0; i < kRegsPerCol; i++) {
+      MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]);
+      MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]);
+      MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]);
+      MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]);
+    }
+  }
+};
+
+// Rx4 += 1x4 * 1x1
+template <int Rows>
+struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<Rows, 4>> {
+  static void Run(const RegBlockInt32<1, 4>& lhs,
+                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) {
+    const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    Int32x4 q[4];
+    q[0] = DupLane<0>(p);
+    q[1] = DupLane<1>(p);
+    q[2] = DupLane<2>(p);
+    q[3] = DupLane<3>(p);
+    static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
+    for (int i = 0; i < kRegsPerCol; i++) {
+      for (int j = 0; j < 4; j++) {
+        acc->buf.reg[i + j * kRegsPerCol] =
+            Add(q[j], acc->buf.reg[i + j * kRegsPerCol]);
+      }
+    }
+  }
+};
+
+// 1xC += 1x1 * 1x1
+template <int Cols>
+struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<1, Cols>> {
+  static void Run(const RegBlockInt32<1, 1>& lhs,
+                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
+    const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
+    for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
+      acc->buf.reg[i] = Add(acc->buf.reg[i], p);
+    }
+  }
+};
+
+// 1x4 += 1x4 * 1x1
+template <>
+struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<1, 4>> {
+  static void Run(const RegBlockInt32<1, 4>& lhs,
+                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) {
+    const std::int32_t p = rhs.buf.reg[0];
+    MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
+  }
+};
+
+// 4xC += 4x1 * 1x1
+template <int Cols>
+struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<4, Cols>> {
+  static void Run(const RegBlockInt32<4, 1>& lhs,
+                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) {
+    const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
+    for (int i = 0; i < Cols; i++) {
+      acc->buf.reg[i] = Add(p, acc->buf.reg[i]);
+    }
+  }
+};
+
+// 4x1 += 4x1 * 1x1
+template <>
+struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
+                           RegBlockInt32<4, 1>> {
+  static void Run(const RegBlockInt32<4, 1>& lhs,
+                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) {
+    const std::int32_t p = rhs.buf.reg[0];
+    MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
+  }
+};
+
+}  // namespace gemmlowp
+
+#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
diff --git a/internal/simd_wrappers_neon.h b/internal/simd_wrappers_neon.h
new file mode 100644
index 0000000..c992b15
--- /dev/null
+++ b/internal/simd_wrappers_neon.h
@@ -0,0 +1,150 @@
+// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// simd_wrappers_neon.h: NEON specialization of simd_wrappers.h
+
+#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_
+#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_
+
+#include <arm_neon.h>
+
+namespace gemmlowp {
+
+using Int32x4 = int32x4_t;
+using Uint8x8 = uint8x8_t;
+
+template <int ScalarCount>
+struct RegisterType<std::int32_t, ScalarCount> {
+  using Type =
+      typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
+};
+
+template <int ScalarCount>
+struct RegisterType<std::uint8_t, ScalarCount> {
+  using Type = typename std::conditional<
+      ScalarCount >= 8, Uint8x8,
+      typename std::conditional<ScalarCount >= 4, std::uint32_t,
+                                std::uint8_t>::type>::type;
+};
+
+inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
+
+inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
+  vst1q_s32(dst, value);
+}
+
+template <int Lane>
+std::int32_t GetLane(Int32x4 value) {
+  return vgetq_lane_s32(value, Lane);
+}
+
+template <int Lane>
+Int32x4 DupLane(Int32x4 value) {
+  switch (Lane) {
+    case 0:
+      return vdupq_lane_s32(vget_low_s32(value), 0);
+    case 1:
+      return vdupq_lane_s32(vget_low_s32(value), 1);
+    case 2:
+      return vdupq_lane_s32(vget_high_s32(value), 0);
+    case 3:
+      return vdupq_lane_s32(vget_high_s32(value), 1);
+    default:
+      static_assert(Lane >= 0 && Lane <= 3, "");
+      return vdupq_n_s32(0);
+  }
+}
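+
+// For example, DupLane<2>({a, b, c, d}) yields {c, c, c, c}.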
+
+inline Int32x4 Mul(Int32x4 a, std::int32_t b) { return vmulq_n_s32(a, b); }
+
+inline Int32x4 Min(Int32x4 a, Int32x4 b) { return vminq_s32(a, b); }
+
+inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); }
+
+inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
+  return vqrdmulhq_n_s32(a, b);
+}
+
+template <int Lane>
+Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
+  switch (Lane) {
+    case 0:
+      return vmulq_lane_s32(a, vget_low_s32(b), 0);
+    case 1:
+      return vmulq_lane_s32(a, vget_low_s32(b), 1);
+    case 2:
+      return vmulq_lane_s32(a, vget_high_s32(b), 0);
+    case 3:
+      return vmulq_lane_s32(a, vget_high_s32(b), 1);
+    default:
+      static_assert(Lane >= 0 && Lane <= 3, "");
+      return vdupq_n_s32(0);
+  }
+}
+
+inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
+  *acc = vmlaq_s32(*acc, lhs, rhs);
+}
+
+inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
+  *acc = vmlaq_n_s32(*acc, lhs, rhs);
+}
+
+template <int Lane>
+inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
+  switch (Lane) {
+    case 0:
+      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 0);
+      break;
+    case 1:
+      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 1);
+      break;
+    case 2:
+      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 0);
+      break;
+    case 3:
+      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 1);
+      break;
+    default:
+      static_assert(Lane >= 0 && Lane <= 3, "");
+  }
+}
+
+template <>
+struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
+  static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
+    RegBlockUint8<8, 8> result;
+    for (int i = 0; i < 8; i++) {
+      result.buf.reg[i] = vld1_u8(src + 8 * i);
+    }
+    return result;
+  }
+};
+
+template <>
+struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
+  static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
+    RegBlockInt32<8, 8> result;
+    for (int i = 0; i < 16; i++) {
+      result.buf.reg[i] = vld1q_s32(src + 4 * i);
+    }
+    return result;
+  }
+};
+
+}  // end namespace gemmlowp
+
+#include "simd_wrappers_common_neon_sse.h"
+
+#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_
diff --git a/internal/simd_wrappers_sse.h b/internal/simd_wrappers_sse.h
new file mode 100644
index 0000000..6480b66
--- /dev/null
+++ b/internal/simd_wrappers_sse.h
@@ -0,0 +1,123 @@
+// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// simd_wrappers_sse.h: SSE SIMD wrappers
+
+#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_SSE_H_
+#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_SSE_H_
+
+#include <smmintrin.h>
+
+namespace gemmlowp {
+
+using Int32x4 = __m128i;
+using Uint8x16 = __m128i;
+
+template <int ScalarCount>
+struct RegisterType<std::int32_t, ScalarCount> {
+  using Type =
+      typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
+};
+
+template <int ScalarCount>
+struct RegisterType<std::uint8_t, ScalarCount> {
+  using Type = typename std::conditional<
+      ScalarCount >= 16, Uint8x16,
+      typename std::conditional<ScalarCount >= 4, std::uint32_t,
+                                std::uint8_t>::type>::type;
+};
+
+inline Int32x4 LoadInt32x4(const std::int32_t* src) {
+  return _mm_loadu_si128(reinterpret_cast<const Int32x4*>(src));
+}
+
+inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
+  _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value);
+}
+
+inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
+  return _mm_loadu_si128(reinterpret_cast<const Uint8x16*>(src));
+}
+
+inline void StoreUint8x16(std::uint8_t* dst, Uint8x16 value) {
+  _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value);
+}
+
+template <int Lane>
+std::int32_t GetLane(Int32x4 value) {
+  return _mm_extract_epi32(value, Lane);
+}
+
+template <int Lane>
+Int32x4 DupLane(Int32x4 value) {
+  return _mm_shuffle_epi32(value, _MM_SHUFFLE(Lane, Lane, Lane, Lane));
+}
+
+inline Int32x4 Mul(Int32x4 a, std::int32_t b) {
+  return Mul(a, Dup<Int32x4>(b));
+}
+
+inline Int32x4 Min(Int32x4 a, Int32x4 b) { return _mm_min_epi32(a, b); }
+
+inline Int32x4 Max(Int32x4 a, Int32x4 b) { return _mm_max_epi32(a, b); }
+
+inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
+  return SaturatingRoundingDoublingHighMul(a, Dup<Int32x4>(b));
+}
+
+template <int Lane>
+Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
+  return Mul(a, DupLane<Lane>(b));
+}
+
+inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
+  *acc = Add(*acc, Mul(lhs, rhs));
+}
+
+inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
+  *acc = Add(*acc, Mul(lhs, rhs));
+}
+
+template <int Lane>
+inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
+  *acc = Add(*acc, MulByRhsLane<Lane>(lhs, rhs));
+}
+
+template <>
+struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
+  static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
+    RegBlockUint8<8, 8> result;
+    for (int i = 0; i < 4; i++) {
+      result.buf.reg[i] = LoadUint8x16(src + 16 * i);
+    }
+    return result;
+  }
+};
+
+template <>
+struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
+  static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
+    RegBlockInt32<8, 8> result;
+    for (int i = 0; i < 16; i++) {
+      result.buf.reg[i] = LoadInt32x4(src + 4 * i);
+    }
+    return result;
+  }
+};
+
+}  // end namespace gemmlowp
+
+#include "simd_wrappers_common_neon_sse.h"
+
+#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_SSE_H_
diff --git a/internal/single_thread_gemm.h b/internal/single_thread_gemm.h
index f40ba55..3d430c5 100644
--- a/internal/single_thread_gemm.h
+++ b/internal/single_thread_gemm.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -28,20 +28,36 @@
 #include "pack.h"
 #include "unpack.h"
 
+#ifdef GEMMLOWP_PROFILING_SIZES
+#ifndef GEMMLOWP_PROFILING
+#error GEMMLOWP_PROFILING_SIZES without GEMMLOWP_PROFILING
+#endif
+#include <string>
+#include <unordered_map>
+#endif
+
 namespace gemmlowp {
 
 class SingleThreadGemmContext {
  public:
   Allocator* allocator() { return &allocator_; }
 
+  void set_l1_bytes_to_use(int n) { l1_bytes_to_use_ = n; }
+  void set_l2_bytes_to_use(int n) { l2_bytes_to_use_ = n; }
+  void set_l2_rhs_factor(float n) { l2_rhs_factor_ = n; }
+
+  int l1_bytes_to_use() const { return l1_bytes_to_use_; }
+  int l2_bytes_to_use() const { return l2_bytes_to_use_; }
+  float l2_rhs_factor() const { return l2_rhs_factor_; }
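+
+  // Illustrative only: a caller could tune these before running a GEMM, e.g.
+  //   SingleThreadGemmContext context;
+  //   context.set_l2_bytes_to_use(1 << 20);  // assume a 1 MiB L2 budget
+  //   context.set_l2_rhs_factor(0.75f);      // share of that budget for RHS
+  // Otherwise the kDefault* values below are used.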
+
  protected:
   Allocator allocator_;
-};
 
-typedef VectorMap<const int32_t, VectorShape::Col> OffsetColMap;
-typedef VectorMap<const int32_t, VectorShape::Row> OffsetRowMap;
-typedef VectorDup<const int32_t, VectorShape::Col> OffsetColDup;
-typedef VectorDup<const int32_t, VectorShape::Row> OffsetRowDup;
+  // The cache configuration to use.
+  int l1_bytes_to_use_ = kDefaultL1CacheSize;
+  int l2_bytes_to_use_ = kDefaultL2CacheSize;
+  float l2_rhs_factor_ = kDefaultL2RhsFactor;
+};
 
 template <typename KernelFormat, typename InputScalar, typename OutputScalar,
           typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
@@ -62,49 +78,75 @@
   int cols = result->cols();
   int depth = lhs.cols();
 
+  // Zero sizes should have been caught earlier and early-returned.
   assert(rows > 0);
   assert(cols > 0);
   assert(depth > 0);
 
+  // The case of rows<cols should have been caught earlier and transposed.
+  assert(rows >= cols);
+
   Allocator* allocator = context->allocator();
 
   BlockParams block_params;
-  block_params.Init<KernelFormat>(rows, cols, depth, 1);
+  block_params.Init<KernelFormat>(rows, cols, depth, 1,
+                                  context->l1_bytes_to_use(),
+                                  context->l2_bytes_to_use(),
+                                  context->l2_rhs_factor());
 
-  PackedSideBlock<typename KernelFormat::Lhs> packed_lhs(
-      Side::Lhs, allocator, block_params);
-  PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(
-      Side::Rhs, allocator, block_params);
+#ifdef GEMMLOWP_PROFILING_SIZES
+  // Using a static map of label strings. Not reentrant at all!
+  static std::unordered_map<std::uint64_t, std::string> labels_map;
+  std::uint64_t sizes_hash = static_cast<std::uint64_t>(rows) ^
+                             (static_cast<std::uint64_t>(depth) << 16) ^
+                             (static_cast<std::uint64_t>(cols) << 32);
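+  // The three sizes are XOR-folded into shifted 16-bit fields of one 64-bit
+  // key; distinct shapes could in principle collide, which is acceptable
+  // since the key only selects a human-readable profiling label.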
+  if (!labels_map.count(sizes_hash)) {
+    char label[256];
+    snprintf(label, sizeof(label),
+             "(rows = %d, depth = %d, cols = %d, l2_rows = %d, l2_depth = %d, "
+             "l2_cols = %d, l1_rows = %d, l1_depth = %d, l1_cols = %d)",
+             rows, depth, cols, block_params.l2_rows, block_params.l2_depth,
+             block_params.l2_cols, block_params.l1_rows, block_params.l1_depth,
+             block_params.l1_cols);
+    labels_map[sizes_hash] = label;
+  }
+  ScopedProfilingLabel size_label(labels_map[sizes_hash].c_str());
+#endif
+
+  PackedSideBlock<typename KernelFormat::Lhs> packed_lhs(Side::Lhs, allocator,
+                                                         block_params);
+  PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(Side::Rhs, allocator,
+                                                         block_params);
 
   PackedResult packed_result(allocator, block_params);
 
   allocator->Commit();
 
-  const bool pack_rhs_once = block_params.l2_cols == cols;
+  const bool pack_rhs_once = block_params.l2_cols >= cols;
 
   if (pack_rhs_once) {
-    PackRhs<BitDepthParams>(&packed_rhs, rhs);
+    PackRhs(&packed_rhs, rhs);
   }
 
   for (int r = 0; r < rows; r += block_params.l2_rows) {
     int rs = std::min(block_params.l2_rows, rows - r);
 
-    PackLhs<BitDepthParams>(&packed_lhs, lhs.block(r, 0, rs, depth));
+    PackLhs(&packed_lhs, lhs.block(r, 0, rs, depth));
 
     for (int c = 0; c < cols; c += block_params.l2_cols) {
       int cs = std::min(block_params.l2_cols, cols - c);
 
       if (!pack_rhs_once) {
-        PackRhs<BitDepthParams>(&packed_rhs, rhs.block(0, c, depth, cs));
+        PackRhs(&packed_rhs, rhs.block(0, c, depth, cs));
       }
 
-      Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs);
+      Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs,
+              depth);
 
-      auto result_block = result->block(r, c, rs, cs);
-      UnpackResult<BitDepthParams>(&result_block, packed_result, depth,
-                                   packed_lhs.sums_of_each_slice(),
-                                   packed_rhs.sums_of_each_slice(),
-                                   lhs_offset, rhs_offset, output_pipeline);
+      UnpackResult<KernelFormat>(
+          result, MatrixBlockBounds(r, c, rs, cs), packed_result, depth,
+          packed_lhs.sums_of_each_slice(), packed_rhs.sums_of_each_slice(),
+          lhs_offset.block(r, rs), rhs_offset.block(c, cs), output_pipeline);
     }
   }
 
diff --git a/internal/unpack.h b/internal/unpack.h
index e25372a..33aee13 100644
--- a/internal/unpack.h
+++ b/internal/unpack.h
@@ -1,4 +1,4 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -55,110 +55,224 @@
   const BlockParams& block_params_;
 };
 
-template <std::uint32_t numerator, std::uint32_t denominator>
-std::int32_t RoundingMultiplyByConstantFraction(std::int32_t x) {
-  if (numerator == denominator) {
-    return x;
+struct MatrixBlockBounds {
+  int start_row;
+  int start_col;
+  int rows;
+  int cols;
+
+  MatrixBlockBounds(int start_row_, int start_col_, int rows_, int cols_)
+      : start_row(start_row_),
+        start_col(start_col_),
+        rows(rows_),
+        cols(cols_) {}
+};
+
+template <int Rows, int Cols, typename SrcMapType>
+void PrefetchResultBlock(const SrcMapType& src,
+                         const VectorMap<const std::int32_t, VectorShape::Col>&
+                             lhs_sums_of_each_slice,
+                         int src_row, int src_col) {
+  const std::int32_t* src_data = src.data(src_row, src_col);
+  const int src_stride = src.stride();
+  const std::int32_t* lhs_sums_data = lhs_sums_of_each_slice.data(src_row);
+  for (int r = 0; r < Rows; r += 4) {
+    Prefetch(lhs_sums_data + r);
   }
-
-  // We'll use only signed arithmetic here. This is
-  // simpler (since this function operates on signed int32's) and
-  // more friendly to ARM NEON, where this allows us to use the
-  // VQRDMULH instruction.
-  static const std::int32_t int_quotient =
-      (numerator + denominator / 2) / denominator;
-  static const std::int32_t remaining_numerator =
-      numerator - int_quotient * denominator;
-  static const std::int32_t scaled_remaining_numerator =
-      static_cast<std::int32_t>(
-          (static_cast<std::int64_t>(remaining_numerator) * (1ll << 31)) /
-          denominator);
-
-  const std::int64_t scaled_remaining_product =
-      static_cast<std::int64_t>(x) *
-      static_cast<std::int64_t>(scaled_remaining_numerator);
-
-  const std::int32_t scaled_remaining_product_nudge =
-      (scaled_remaining_product > 0 ? 1 : -1) * (1 << 30);
-
-  const std::int32_t remaining_product = static_cast<std::int32_t>(
-      (scaled_remaining_product + scaled_remaining_product_nudge) / (1u << 31));
-
-  return x * int_quotient + remaining_product;
+  for (int c = 0; c < Cols; c++) {
+    for (int r = 0; r < Rows; r += 4) {
+      Prefetch(src_data + r + c * src_stride);
+    }
+  }
 }
 
-template <typename BitDepthParams, typename ResultBlockType,
+template <typename KernelFormat, typename RegisterBlockType,
+          typename SrcMapType, typename LhsOffset, typename RhsOffset,
+          typename OutputPipelineExecutorType, typename DstType>
+void UnpackResultBlock(const SrcMapType& src,
+                       const OutputPipelineExecutorType& executor, DstType* dst,
+                       const VectorMap<const std::int32_t, VectorShape::Col>&
+                           lhs_sums_of_each_slice,
+                       const VectorMap<const std::int32_t, VectorShape::Row>&
+                           rhs_sums_of_each_slice,
+                       const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
+                       int depth, int src_row, int src_col, int src_global_row,
+                       int src_global_col, int dst_row, int dst_col) {
+  using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
+  using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
+  static constexpr int KernelLhsZeroPointInput =
+      ZeroPointInputValue<KernelLhsScalar>::kValue;
+  static constexpr int KernelRhsZeroPointInput =
+      ZeroPointInputValue<KernelRhsScalar>::kValue;
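+  // The accumulator is assembled from the same four terms as the old generic
+  // unpack code (see doc/low-precision.txt): the raw result (xx), plus
+  // lhs_sums * rhs_offset (x1), plus rhs_sums * lhs_offset (1x), plus
+  // depth * lhs_offset * rhs_offset (11), with both offsets first shifted by
+  // the kernel's zero-point input values.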
+  auto acc = Load<RegisterBlockType>(src, src_row, src_col);
+  const auto& lhs_sums_of_each_slice_block =
+      LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row);
+  const auto& rhs_sums_of_each_slice_block =
+      LoadForBroadcasting<RegisterBlockType>(rhs_sums_of_each_slice, src_col);
+  auto lhs_offset_block =
+      LoadForBroadcasting<RegisterBlockType>(lhs_offset, src_row);
+  auto rhs_offset_block =
+      LoadForBroadcasting<RegisterBlockType>(rhs_offset, src_col);
+  AddConstant<KernelLhsZeroPointInput>(&lhs_offset_block);
+  AddConstant<KernelRhsZeroPointInput>(&rhs_offset_block);
+  BroadcastMulAdd(lhs_sums_of_each_slice_block, rhs_offset_block, &acc);
+  for (int i = 0; i < decltype(rhs_offset_block)::kRegisterCount; i++) {
+    rhs_offset_block.buf.reg[i] = Mul(rhs_offset_block.buf.reg[i], depth);
+  }
+  BroadcastMulAdd(BroadcastAdd(rhs_sums_of_each_slice_block, rhs_offset_block),
+                  lhs_offset_block, &acc);
+  executor.Execute(acc, dst, src_global_row, src_global_col, dst_row, dst_col);
+}
+
+template <typename KernelFormat, typename ResultBlockType,
           typename PackedResultType, typename LhsOffset, typename RhsOffset,
           typename OutputPipelineType>
-struct UnpackResultImplGeneric {
-  static void Unpack(ResultBlockType* dst, const PackedResultType& src,
-                     int depth, const std::int32_t* lhs_sums_of_each_slice,
-                     const std::int32_t* rhs_sums_of_each_slice,
-                     const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
-                     const OutputPipelineType& output_pipeline) {
-    auto src_map = src.Map();
-    // No top-level blocking in the depth dimension at the moment.
-    // Too much loss of precision.
-    const int kLhsBits = BitDepthParams::LhsBitDepth::kBits;
-    const int kRhsBits = BitDepthParams::RhsBitDepth::kBits;
-    const std::int32_t kLhsMax = (1 << kLhsBits) - 1;
-    const std::int32_t kRhsMax = (1 << kRhsBits) - 1;
-    OutputPipelineExecutor<OutputPipelineType, FragmentInt32x1x1>
-        output_pipeline_executor(output_pipeline);
-    for (int c = 0; c < dst->cols(); c++) {
-      for (int r = 0; r < dst->rows(); r++) {
-        // To understand this code, read
-        //   doc/low-precision.txt
-        //   doc/less-than-8-bit.txt
-        // We have 4 terms to sum: xx, x1, 1x, 11.
-        // In case of requantization, we first need to scale them back
-        // to the original scale, using RoundingMultiplyByConstantFraction.
-        std::int32_t raw_xx = src_map(r, c);
-        std::int32_t raw_x1 = lhs_sums_of_each_slice[r] * rhs_offset(c);
-        std::int32_t raw_1x = rhs_sums_of_each_slice[c] * lhs_offset(r);
-        std::int32_t term_xx =
-            RoundingMultiplyByConstantFraction<255 * 255, kLhsMax * kRhsMax>(
-                raw_xx);
-        std::int32_t term_x1 =
-            RoundingMultiplyByConstantFraction<255, kLhsMax>(raw_x1);
-        std::int32_t term_1x =
-            RoundingMultiplyByConstantFraction<255, kRhsMax>(raw_1x);
-        std::int32_t term_11 = lhs_offset(r) * rhs_offset(c) * depth;
-        // Sum the 4 terms.
-        FragmentInt32x1x1 sum = term_xx + term_x1 + term_1x + term_11;
+void UnpackResult(ResultBlockType* dst, const MatrixBlockBounds& dst_block,
+                  const PackedResultType& src, int depth,
+                  const std::int32_t* lhs_sums_of_each_slice_ptr,
+                  const std::int32_t* rhs_sums_of_each_slice_ptr,
+                  const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
+                  const OutputPipelineType& output_pipeline) {
+  ScopedProfilingLabel label(ResultBlockType::kOrder == MapOrder::ColMajor
+                                 ? "unpack to column-major"
+                                 : "unpack to row-major");
+  assert(dst_block.start_row >= 0);
+  assert(dst_block.start_row + dst_block.rows <= dst->rows());
+  assert(dst_block.start_col >= 0);
+  assert(dst_block.start_col + dst_block.cols <= dst->cols());
+  const auto src_map = src.Map();
+  const VectorMap<const std::int32_t, VectorShape::Col> lhs_sums_of_each_slice(
+      lhs_sums_of_each_slice_ptr, dst_block.rows);
+  const VectorMap<const std::int32_t, VectorShape::Row> rhs_sums_of_each_slice(
+      rhs_sums_of_each_slice_ptr, dst_block.cols);
+  using Int32x1x1 = RegisterBlock<std::int32_t, 1, 1>;
+  using Int32x4x1 = RegisterBlock<std::int32_t, 4, 1>;
+  using Int32x8x1 = RegisterBlock<std::int32_t, 8, 1>;
+  using Int32x1x4 = RegisterBlock<std::int32_t, 1, 4>;
+  using Int32x4x4 = RegisterBlock<std::int32_t, 4, 4>;
+  using Int32x8x4 = RegisterBlock<std::int32_t, 8, 4>;
 
-        output_pipeline_executor.Execute(sum, dst, r, c);
+  using DstScalarType = typename ResultBlockType::Scalar;
+  using DstScalarx8x8 = RegisterBlock<DstScalarType, 8, 8>;
+
+  OutputPipelineExecutor<OutputPipelineType, Int32x1x1>
+      output_pipeline_executor_1x1(output_pipeline);
+  OutputPipelineExecutor<OutputPipelineType, Int32x4x1>
+      output_pipeline_executor_4x1(output_pipeline);
+  OutputPipelineExecutor<OutputPipelineType, Int32x8x1>
+      output_pipeline_executor_8x1(output_pipeline);
+  OutputPipelineExecutor<OutputPipelineType, Int32x1x4>
+      output_pipeline_executor_1x4(output_pipeline);
+  OutputPipelineExecutor<OutputPipelineType, Int32x4x4>
+      output_pipeline_executor_4x4(output_pipeline);
+  OutputPipelineExecutor<OutputPipelineType, Int32x8x4>
+      output_pipeline_executor_8x4(output_pipeline);
+
+  int c8 = 0;
+  if (ResultBlockType::kOrder == MapOrder::RowMajor) {
+    for (; c8 <= dst_block.cols - 8; c8 += 8) {
+      PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, 0, c8);
+      int r = 0;
+      for (; r <= dst_block.rows - 8; r += 8) {
+        const int global_row = r + dst_block.start_row;
+        PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, r + 8, c8);
+        DstScalarType dst_colmajor_buf[64];
+        MatrixMap<DstScalarType, MapOrder::ColMajor> dst_colmajor_map(
+            dst_colmajor_buf, 8, 8);
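+        // Unpack this 8x8 tile into a small column-major scratch buffer,
+        // then write it out with a single contiguous 8x8 StoreFinalOutput
+        // below.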
+        for (int cx = 0; cx < 8; cx += 4) {
+          const int c = c8 + cx;
+          const int global_col = c + dst_block.start_col;
+          UnpackResultBlock<KernelFormat, Int32x8x4>(
+              src_map, output_pipeline_executor_8x4, &dst_colmajor_map,
+              lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
+              rhs_offset, depth, r, c, global_row, global_col, 0, cx);
+        }
+        StoreFinalOutput(LoadContiguous<DstScalarx8x8>(dst_colmajor_buf), dst,
+                         r + dst_block.start_row, c8 + dst_block.start_col);
+      }
+      for (; r <= dst_block.rows - 4; r += 4) {
+        const int global_row = r + dst_block.start_row;
+        for (int cx = 0; cx < 8; cx += 4) {
+          const int c = c8 + cx;
+          const int global_col = c + dst_block.start_col;
+          UnpackResultBlock<KernelFormat, Int32x4x4>(
+              src_map, output_pipeline_executor_4x4, dst,
+              lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
+              rhs_offset, depth, r, c, global_row, global_col, global_row,
+              global_col);
+        }
+      }
+      for (; r < dst_block.rows; r++) {
+        const int global_row = r + dst_block.start_row;
+        for (int cx = 0; cx < 8; cx += 4) {
+          const int c = c8 + cx;
+          const int global_col = c + dst_block.start_col;
+          UnpackResultBlock<KernelFormat, Int32x1x4>(
+              src_map, output_pipeline_executor_1x4, dst,
+              lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
+              rhs_offset, depth, r, c, global_row, global_col, global_row,
+              global_col);
+        }
       }
     }
   }
-};
-
-template <typename BitDepthParams, typename ResultBlockType,
-          typename PackedResultType, typename LhsOffset, typename RhsOffset,
-          typename OutputPipelineType>
-struct UnpackResultImpl
-    : UnpackResultImplGeneric<BitDepthParams, ResultBlockType, PackedResultType,
-                              LhsOffset, RhsOffset, OutputPipelineType> {};
-
-template <typename BitDepthParams, typename ResultBlockType,
-          typename PackedResultType, typename LhsOffset, typename RhsOffset,
-          typename OutputPipelineType>
-void UnpackResult(ResultBlockType* dst, const PackedResultType& src, int depth,
-                  const std::int32_t* lhs_sums_of_each_slice,
-                  const std::int32_t* rhs_sums_of_each_slice,
-                  const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
-                  const OutputPipelineType& output_pipeline) {
-  ScopedProfilingLabel label("unpack");
-  UnpackResultImpl<BitDepthParams, ResultBlockType, PackedResultType,
-                   LhsOffset, RhsOffset, OutputPipelineType>::Unpack(
-      dst, src, depth, lhs_sums_of_each_slice, rhs_sums_of_each_slice,
-      lhs_offset, rhs_offset, output_pipeline);
+  int c = c8;
+  for (; c <= dst_block.cols - 4; c += 4) {
+    const int global_col = c + dst_block.start_col;
+    PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, 0, c);
+    int r = 0;
+    for (; r <= dst_block.rows - 8; r += 8) {
+      const int global_row = r + dst_block.start_row;
+      PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, r + 8, c);
+      UnpackResultBlock<KernelFormat, Int32x8x4>(
+          src_map, output_pipeline_executor_8x4, dst, lhs_sums_of_each_slice,
+          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
+          global_row, global_col, global_row, global_col);
+    }
+    for (; r <= dst_block.rows - 4; r += 4) {
+      const int global_row = r + dst_block.start_row;
+      UnpackResultBlock<KernelFormat, Int32x4x4>(
+          src_map, output_pipeline_executor_4x4, dst, lhs_sums_of_each_slice,
+          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
+          global_row, global_col, global_row, global_col);
+    }
+    for (; r < dst_block.rows; r++) {
+      const int global_row = r + dst_block.start_row;
+      UnpackResultBlock<KernelFormat, Int32x1x4>(
+          src_map, output_pipeline_executor_1x4, dst, lhs_sums_of_each_slice,
+          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
+          global_row, global_col, global_row, global_col);
+    }
+  }
+  for (; c < dst_block.cols; c++) {
+    const int global_col = c + dst_block.start_col;
+    PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, 0, c);
+    int r = 0;
+    for (; r <= dst_block.rows - 8; r += 8) {
+      const int global_row = r + dst_block.start_row;
+      PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, r + 8, c);
+      UnpackResultBlock<KernelFormat, Int32x8x1>(
+          src_map, output_pipeline_executor_8x1, dst, lhs_sums_of_each_slice,
+          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
+          global_row, global_col, global_row, global_col);
+    }
+    for (; r <= dst_block.rows - 4; r += 4) {
+      const int global_row = r + dst_block.start_row;
+      UnpackResultBlock<KernelFormat, Int32x4x1>(
+          src_map, output_pipeline_executor_4x1, dst, lhs_sums_of_each_slice,
+          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
+          global_row, global_col, global_row, global_col);
+    }
+    for (; r < dst_block.rows; r++) {
+      const int global_row = r + dst_block.start_row;
+      UnpackResultBlock<KernelFormat, Int32x1x1>(
+          src_map, output_pipeline_executor_1x1, dst, lhs_sums_of_each_slice,
+          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
+          global_row, global_col, global_row, global_col);
+    }
+  }
 }
 
-}  // namespace gemmlowp
-
-#ifdef GEMMLOWP_NEON
-#include "unpack_neon.h"
-#endif
+}  // end namespace gemmlowp
 
 #endif  // GEMMLOWP_INTERNAL_UNPACK_H_
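
A note on the new unpack loop above: it walks the destination block in column strips (the 4-wide strips and the single-column remainder are visible in this hunk; the `int c = c8;` starting point suggests an 8-wide pass just before it), and inside each strip it consumes rows in blocks of 8, then 4, then 1, prefetching the next tile of the packed result before unpacking the current one. The sketch below is not gemmlowp code; it shows only that fallback-tiling skeleton. Visit is a hypothetical stand-in for the PrefetchResultBlock/UnpackResultBlock calls in the diff, and the 8-column pass is left out to keep it short.

    #include <cstdio>

    // Hypothetical per-block worker; in the real code this is where the
    // PrefetchResultBlock / UnpackResultBlock<KernelFormat, ...> calls go.
    template <int Rows, int Cols>
    void Visit(int r, int c) {
      std::printf("block %dx%d at (%d, %d)\n", Rows, Cols, r, c);
    }

    // Fallback tiling: 4-wide column strips with a 1-column remainder;
    // within each strip, row blocks of 8, then 4, then 1.
    void TileBlock(int rows, int cols) {
      int c = 0;
      for (; c <= cols - 4; c += 4) {
        int r = 0;
        for (; r <= rows - 8; r += 8) Visit<8, 4>(r, c);
        for (; r <= rows - 4; r += 4) Visit<4, 4>(r, c);
        for (; r < rows; r++) Visit<1, 4>(r, c);
      }
      for (; c < cols; c++) {
        int r = 0;
        for (; r <= rows - 8; r += 8) Visit<8, 1>(r, c);
        for (; r <= rows - 4; r += 4) Visit<4, 1>(r, c);
        for (; r < rows; r++) Visit<1, 1>(r, c);
      }
    }

    int main() { TileBlock(13, 6); return 0; }

The point of the descending block sizes is that every row and column count is handled without a separate unaligned code path: whatever the 8-wide loops leave behind is picked up by the 4-wide loops, and the scalar loops take the rest.
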
diff --git a/internal/unpack_neon.h b/internal/unpack_neon.h
index 394f10a..5c9e76a 100644
--- a/internal/unpack_neon.h
+++ b/internal/unpack_neon.h
@@ -73,12 +73,17 @@
                         PackedResultType, LhsOffset, RhsOffset,
                         OutputPipelineType> {
   typedef MatrixMap<OutputScalar, MapOrder::ColMajor> ResultBlockType;
-  static void Unpack(ResultBlockType* dst, const PackedResultType& src,
-                     int depth, const std::int32_t* lhs_sums_of_each_slice,
+  static void Unpack(ResultBlockType* dst, const MatrixBlockBounds& dst_block,
+                     const PackedResultType& src, int depth,
+                     const std::int32_t* lhs_sums_of_each_slice,
                      const std::int32_t* rhs_sums_of_each_slice,
                      const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                      const OutputPipelineType& output_pipeline) {
     ScopedProfilingLabel label("optimized path (NEON)");
+    assert(dst_block.start_row >= 0);
+    assert(dst_block.start_row + dst_block.rows <= dst->rows());
+    assert(dst_block.start_col >= 0);
+    assert(dst_block.start_col + dst_block.cols <= dst->cols());
     const int kLhsBits = BitDepthParams::LhsBitDepth::kBits;
     const int kRhsBits = BitDepthParams::RhsBitDepth::kBits;
     const std::int32_t kLhsMax = (1 << kLhsBits) - 1;
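
For readers of the new signature above: this Unpack specialization now writes into a sub-rectangle of the destination described by dst_block, and the four asserts encode the invariant that [start_row, start_row + rows) x [start_col, start_col + cols) must lie inside dst. A minimal sketch of that invariant follows, assuming a struct with just the four fields the asserts use (the real MatrixBlockBounds is defined elsewhere in this patch); it is an illustration, not the gemmlowp type.

    // Hypothetical stand-in for the MatrixBlockBounds type used above,
    // reduced to the four fields the asserts rely on.
    struct BlockBounds {
      int start_row, start_col;
      int rows, cols;
    };

    // True iff the sub-block lies entirely inside a destination matrix of
    // size dst_rows x dst_cols -- the same condition the four asserts check.
    inline bool BlockFitsInDst(const BlockBounds& b, int dst_rows,
                               int dst_cols) {
      return b.start_row >= 0 && b.start_row + b.rows <= dst_rows &&
             b.start_col >= 0 && b.start_col + b.cols <= dst_cols;
    }

Local coordinates (r, c) within the block then translate to destination coordinates as r_dst = r + start_row and c_dst = c + start_col, which is the mapping the rest of this file's hunks apply.
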
@@ -91,16 +96,18 @@
     OutputPipelineExecutor<OutputPipelineType, NEONFragmentInt32x16x1>
         output_pipeline_executor_int32x16x1(output_pipeline);
 
-    for (int c = 0; c < dst->cols(); c++) {
+    for (int c = 0; c < dst_block.cols; c++) {
+      int c_dst = c + dst_block.start_col;
       const std::int32_t* src_ptr = src_map.data(0, c);
       const std::int32_t* sums_of_each_slice_ptr = lhs_sums_of_each_slice;
-      auto lhs_offset_iter = const_iterator(lhs_offset);
-      const std::int32_t rhs_offset_c = rhs_offset(c);
+      auto lhs_offset_iter = const_iterator(lhs_offset, dst_block.start_row);
+      const std::int32_t rhs_offset_c = rhs_offset(c_dst);
       const std::int32_t rhs_sums_of_each_slice_c = rhs_sums_of_each_slice[c];
 
       // Handle 16 values at once for higher performance
-      int dst_rows_aligned16 = RoundDown<16>(dst->rows());
+      int dst_rows_aligned16 = RoundDown<16>(dst_block.rows);
       for (int r = 0; r < dst_rows_aligned16; r += 16) {
+        int r_dst = r + dst_block.start_row;
         // Compute the sum of the 4 terms,
         //   q = term_xx + term_x1 + term_1x_plus_term_11
         // Refer to the generic code in unpack.h.
@@ -144,12 +151,13 @@
                                vaddq_s32(term_1x[i], term_11[i]));
         }
         NEONFragmentInt32x16x1 f(q);
-        output_pipeline_executor_int32x16x1.Execute(f, dst, r, c);
+        output_pipeline_executor_int32x16x1.Execute(f, dst, r_dst, c_dst);
       }
       // We have finished handling groups of 16 entries at once; now
       // try to handle 4 entries at once.
-      int dst_rows_aligned4 = RoundDown<4>(dst->rows());
+      int dst_rows_aligned4 = RoundDown<4>(dst_block.rows);
       for (int r = dst_rows_aligned16; r < dst_rows_aligned4; r += 4) {
+        int r_dst = r + dst_block.start_row;
         // Compute the sum of the 4 terms,
         //   q = term_xx + term_x1 + term_1x_plus_term_11
         // Refer to the generic code in unpack.h.
@@ -173,15 +181,17 @@
         int32x4_t q = vaddq_s32(vaddq_s32(term_xx, term_x1),
                                 vaddq_s32(term_1x, term_11));
         NEONFragmentInt32x4x1 f(q);
-        output_pipeline_executor_int32x4x1.Execute(f, dst, r, c);
+        output_pipeline_executor_int32x4x1.Execute(f, dst, r_dst, c_dst);
       }
       // We have finished handling 4 entries at once; now handle
       // remaining entries one by one. This scalar code is similar
       // to the code in unpack.h, see comments there.
-      for (int r = dst_rows_aligned4; r < dst->rows(); r++) {
+      for (int r = dst_rows_aligned4; r < dst_block.rows; r++) {
+        int r_dst = r + dst_block.start_row;
         const std::int32_t raw_xx = src_map(r, c);
         const std::int32_t raw_x1 = lhs_sums_of_each_slice[r] * rhs_offset_c;
-        const std::int32_t raw_1x = rhs_sums_of_each_slice_c * lhs_offset(r);
+        const std::int32_t raw_1x =
+            rhs_sums_of_each_slice_c * lhs_offset(r_dst);
         const std::int32_t term_xx =
             RoundingMultiplyByConstantFraction<255 * 255, kLhsMax * kRhsMax>(
                 raw_xx);
@@ -191,7 +201,7 @@
             RoundingMultiplyByConstantFraction<255, kRhsMax>(raw_1x);
         const std::int32_t term_11 = lhs_offset(r) * rhs_offset(c) * depth;
         FragmentInt32x1x1 sum = term_xx + term_x1 + term_1x + term_11;
-        output_pipeline_executor_int32x1x1.Execute(sum, dst, r, c);
+        output_pipeline_executor_int32x1x1.Execute(sum, dst, r_dst, c_dst);
       }
     }
   }
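
The scalar tail above also makes the arithmetic of the unpack stage explicit: summing (lhs + lhs_offset) * (rhs + rhs_offset) over the depth dimension splits into the four terms term_xx, term_x1, term_1x and term_11, of which the first three are rescaled by RoundingMultiplyByConstantFraction to map the reduced bit depths back onto the 8-bit range. Below is a plain-integer sketch of that decomposition with hypothetical names, the rescaling omitted, and the offsets indexed by destination (global) coordinates as in the r_dst/c_dst lines above; it illustrates the formula and is not gemmlowp code.

    #include <cstdint>

    // Hypothetical helper computing one unpacked entry from the per-slice
    // sums that the packed operands carry. The rescaling the NEON code
    // applies to the first three terms is omitted.
    //   raw_xx         = sum over k of lhs[r][k] * rhs[k][c]   (src_map(r, c))
    //   lhs_sum_of_row = sum over k of lhs[r][k]  (lhs_sums_of_each_slice[r])
    //   rhs_sum_of_col = sum over k of rhs[k][c]  (rhs_sums_of_each_slice[c])
    //   lhs_off, rhs_off = lhs_offset(r_dst), rhs_offset(c_dst)
    std::int32_t UnpackOneEntry(std::int32_t raw_xx,
                                std::int32_t lhs_sum_of_row,
                                std::int32_t rhs_sum_of_col,
                                std::int32_t lhs_off, std::int32_t rhs_off,
                                int depth) {
      const std::int32_t term_xx = raw_xx;                     // data * data
      const std::int32_t term_x1 = lhs_sum_of_row * rhs_off;   // data * offset
      const std::int32_t term_1x = rhs_sum_of_col * lhs_off;   // offset * data
      const std::int32_t term_11 = lhs_off * rhs_off * depth;  // offset * offset
      return term_xx + term_x1 + term_1x + term_11;
    }

The row and column sums are the lhs_sums_of_each_slice / rhs_sums_of_each_slice arrays that the packed operands already carry, which is what lets the offset cross-terms be applied once per row or column here instead of inside the depth loop.
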