Fully Connected operator for QS8 datatype
PiperOrigin-RevId: 366203468
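
A minimal end-to-end sketch of the new API (all sizes, weights, and quantization parameters below are invented for illustration; the only contract it leans on from this change is the create/setup signature pair and the requirement, enforced in xnn_create_fully_connected_nc_qs8, that input_scale * kernel_scale / output_scale stay below 1.0):

```c
// Hypothetical usage sketch; sizes, weights, and quantization parameters
// below are invented for illustration.
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <xnnpack.h>

int main(void) {
  if (xnn_initialize(NULL /* allocator */) != xnn_status_success) {
    return EXIT_FAILURE;
  }

  enum { kInputChannels = 4, kOutputChannels = 2 };
  // Default (non-transposed) layout: kernel[output_channel][input_channel].
  const int8_t kernel[kOutputChannels * kInputChannels] = {
     1,  2,  3,  4,
    -1, -2, -3, -4,
  };
  const int32_t bias[kOutputChannels] = { 100, -100 };
  // Pad the input with XNN_EXTRA_BYTES, as the operator tester does.
  int8_t* input = malloc(kInputChannels + XNN_EXTRA_BYTES);
  int8_t output[kOutputChannels];
  memset(input, 1, kInputChannels);

  // 0.5f * 0.5f / 1.0f = 0.25 satisfies the requantization-scale < 1.0 check.
  xnn_operator_t fc_op = NULL;
  enum xnn_status status = xnn_create_fully_connected_nc_qs8(
    kInputChannels, kOutputChannels,
    kInputChannels /* input stride */, kOutputChannels /* output stride */,
    0 /* input zero point */, 0.5f /* input scale */,
    0.5f /* kernel scale */,
    kernel, bias,
    0 /* output zero point */, 1.0f /* output scale */,
    INT8_MIN /* output min */, INT8_MAX /* output max */,
    0 /* flags */, &fc_op);
  if (status != xnn_status_success) {
    return EXIT_FAILURE;
  }

  status = xnn_setup_fully_connected_nc_qs8(
    fc_op, 1 /* batch size */, input, output, NULL /* thread pool */);
  if (status == xnn_status_success) {
    status = xnn_run_operator(fc_op, NULL /* thread pool */);
  }

  xnn_delete_operator(fc_op);
  free(input);
  return status == xnn_status_success ? EXIT_SUCCESS : EXIT_FAILURE;
}
```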
diff --git a/include/xnnpack.h b/include/xnnpack.h
index 39c8617..c55fb16 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -1998,6 +1998,30 @@
int8_t* output,
pthreadpool_t threadpool);
+enum xnn_status xnn_create_fully_connected_nc_qs8(
+ size_t input_channels,
+ size_t output_channels,
+ size_t input_stride,
+ size_t output_stride,
+ int8_t input_zero_point,
+ float input_scale,
+ float kernel_scale,
+ const int8_t* kernel,
+ const int32_t* bias,
+ int8_t output_zero_point,
+ float output_scale,
+ int8_t output_min,
+ int8_t output_max,
+ uint32_t flags,
+ xnn_operator_t* fully_connected_op_out);
+
+enum xnn_status xnn_setup_fully_connected_nc_qs8(
+ xnn_operator_t fully_connected_op,
+ size_t batch_size,
+ const int8_t* input,
+ int8_t* output,
+ pthreadpool_t threadpool);
+
enum xnn_status xnn_create_global_average_pooling_nwc_qs8(
size_t channels,
size_t input_stride,
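
The zero points and scales in xnn_create_fully_connected_nc_qs8 above follow the same affine quantization convention as the existing QU8 operator: a quantized value q represents the real value scale * (q - zero_point), and the bias is a raw int32 at scale input_scale * kernel_scale. Under that convention the int32 accumulator and the requantized int8 output relate as:

```latex
acc_{oc} = b_{oc} + \sum_{ic} \bigl(x_{ic} - Z_{in}\bigr)\, w_{oc,\,ic},
\qquad
y_{oc} \approx \operatorname{clamp}\!\Bigl(Z_{out} + \tfrac{S_{in} S_{w}}{S_{out}}\, acc_{oc},\; y_{min},\; y_{max}\Bigr)
```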
diff --git a/src/operator-strings.c b/src/operator-strings.c
index 1af6157..f4b9fa7 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -74,6 +74,8 @@
return "Floor (NC, F32)";
case xnn_operator_type_fully_connected_nc_f32:
return "Fully Connected (NC, F32)";
+ case xnn_operator_type_fully_connected_nc_qs8:
+ return "Fully Connected (NC, QS8)";
case xnn_operator_type_fully_connected_nc_qu8:
return "Fully Connected (NC, QU8)";
case xnn_operator_type_global_average_pooling_nwc_f16:
diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c
index ff4485a..3c0c352 100644
--- a/src/operators/fully-connected-nc.c
+++ b/src/operators/fully-connected-nc.c
@@ -41,6 +41,7 @@
size_t params_size,
const struct gemm_parameters* gemm_parameters,
const struct gemm_fused_ukernels* gemm_ukernels,
+ uint32_t datatype_init_flags,
enum xnn_operator_type operator_type,
xnn_operator_t* fully_connected_op_out)
{
@@ -53,6 +54,15 @@
goto error;
}
+ status = xnn_status_unsupported_hardware;
+
+ if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
+ xnn_log_error(
+ "failed to create %s operator: operations on data type are not supported",
+ xnn_operator_type_to_string(operator_type));
+ goto error;
+ }
+
status = xnn_status_invalid_parameter;
if (input_channels == 0) {
@@ -160,6 +170,7 @@
size_t batch_size,
const void* input,
void* output,
+ uint32_t datatype_init_flags,
uint32_t log2_input_element_size,
uint32_t log2_filter_element_size,
uint32_t bias_element_size,
@@ -309,10 +320,87 @@
&packing_params, kernel_zero_point /* packed weights padding byte */,
&params, sizeof(params),
&xnn_params.qu8.gemm, &xnn_params.qu8.gemm.minmax,
+ XNN_INIT_FLAG_QU8,
xnn_operator_type_fully_connected_nc_qu8,
fully_connected_op_out);
}
+enum xnn_status xnn_create_fully_connected_nc_qs8(
+ size_t input_channels,
+ size_t output_channels,
+ size_t input_stride,
+ size_t output_stride,
+ int8_t input_zero_point,
+ float input_scale,
+ float kernel_scale,
+ const int8_t* kernel,
+ const int32_t* bias,
+ int8_t output_zero_point,
+ float output_scale,
+ int8_t output_min,
+ int8_t output_max,
+ uint32_t flags,
+ xnn_operator_t* fully_connected_op_out)
+{
+ if (input_scale <= 0.0f || !isnormal(input_scale)) {
+ xnn_log_error(
+ "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
+ xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), input_scale);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
+ xnn_log_error(
+ "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
+ xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), kernel_scale);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (output_scale <= 0.0f || !isnormal(output_scale)) {
+ xnn_log_error(
+ "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
+ xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), output_scale);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (output_min >= output_max) {
+ xnn_log_error(
+ "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
+ xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), output_min, output_max);
+ return xnn_status_invalid_parameter;
+ }
+
+ const float requantization_scale = input_scale * kernel_scale / output_scale;
+ if (requantization_scale >= 1.0f) {
+ xnn_log_error(
+ "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
+ "requantization scale %.7g is greater or equal to 1.0",
+ xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8),
+ input_scale, kernel_scale, output_scale, requantization_scale);
+ return xnn_status_unsupported_parameter;
+ }
+
+ const union xnn_qs8_gemm_params params = xnn_init_qs8_gemm_params(
+ requantization_scale, output_zero_point, output_min, output_max);
+ const struct xnn_qs8_packing_params packing_params = {
+ .input_zero_point = input_zero_point,
+ };
+ return create_fully_connected_nc(
+ input_channels, output_channels,
+ input_stride, output_stride,
+ kernel, bias, flags,
+ 0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
+ sizeof(int32_t) /* sizeof(bias element) */,
+ (xnn_pack_gemm_io_w_function) xnn_pack_qs8_gemm_io_w,
+ (xnn_pack_gemm_goi_w_function) xnn_pack_qs8_gemm_goi_w,
+ &packing_params, 0 /* packed weights padding byte */,
+ &params, sizeof(params),
+ &xnn_params.qs8.gemm, &xnn_params.qs8.gemm.minmax,
+ XNN_INIT_FLAG_QS8,
+ xnn_operator_type_fully_connected_nc_qs8,
+ fully_connected_op_out);
+}
+
enum xnn_status xnn_create_fully_connected_nc_f32(
size_t input_channels,
size_t output_channels,
@@ -364,6 +452,7 @@
NULL /* packing params */, 0 /* packed weights padding byte */,
&params, sizeof(params),
&xnn_params.f32.gemm, gemm_ukernels,
+ XNN_INIT_FLAG_F32,
xnn_operator_type_fully_connected_nc_f32,
fully_connected_op_out);
}
@@ -386,6 +475,7 @@
fully_connected_op,
batch_size,
input, output,
+ XNN_INIT_FLAG_QU8,
0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
sizeof(int32_t) /* sizeof(bias element) */,
@@ -395,6 +485,34 @@
pthreadpool_get_threads_count(threadpool));
}
+enum xnn_status xnn_setup_fully_connected_nc_qs8(
+ xnn_operator_t fully_connected_op,
+ size_t batch_size,
+ const int8_t* input,
+ int8_t* output,
+ pthreadpool_t threadpool)
+{
+ if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_qs8) {
+ xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
+ xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8),
+ xnn_operator_type_to_string(fully_connected_op->type));
+ return xnn_status_invalid_parameter;
+ }
+
+ return setup_fully_connected_nc(
+ fully_connected_op,
+ batch_size,
+ input, output,
+ XNN_INIT_FLAG_QS8,
+ 0 /* log2(sizeof(input element)) = log2(sizeof(int8_t)) */,
+ 0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
+ sizeof(int32_t) /* sizeof(bias element) */,
+ 0 /* log2(sizeof(output element)) = log2(sizeof(int8_t)) */,
+ &fully_connected_op->params.qs8_gemm,
+ sizeof(fully_connected_op->params.qs8_gemm),
+ pthreadpool_get_threads_count(threadpool));
+}
+
enum xnn_status xnn_setup_fully_connected_nc_f32(
xnn_operator_t fully_connected_op,
size_t batch_size,
@@ -413,6 +531,7 @@
fully_connected_op,
batch_size,
input, output,
+ XNN_INIT_FLAG_F32,
2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
sizeof(float) /* sizeof(bias element) */,
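
One note on the new xnn_status_unsupported_parameter path above: the QS8 microkernels requantize the int32 accumulator with an effective scale of input_scale * kernel_scale / output_scale, and that representation only supports scales below 1.0, hence the explicit check. A worked example (values invented, helper hypothetical):

```c
// Mirrors the validation in xnn_create_fully_connected_nc_qs8.
//
//   input_scale = 0.5f, kernel_scale = 0.25f, output_scale = 1.0f
//     -> requantization_scale = 0.5 * 0.25 / 1.0 = 0.125   (accepted)
//
//   input_scale = 2.0f, kernel_scale = 1.0f, output_scale = 1.0f
//     -> requantization_scale = 2.0   (rejected with
//        xnn_status_unsupported_parameter)
static int qs8_fc_scales_supported(
    float input_scale, float kernel_scale, float output_scale) {
  return input_scale * kernel_scale / output_scale < 1.0f;
}
```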
diff --git a/src/packing.c b/src/packing.c
index 24ade5b..2e232a4 100644
--- a/src/packing.c
+++ b/src/packing.c
@@ -426,6 +426,53 @@
}
}
+void xnn_pack_qs8_gemm_io_w(
+ size_t nc,
+ size_t kc,
+ size_t nr,
+ size_t kr,
+ size_t sr,
+ const int8_t* k,
+ const int32_t* b,
+ void* packed_w,
+ const struct xnn_qs8_packing_params* params)
+{
+ assert(sr == 1);
+ const int32_t izp = (int32_t) params->input_zero_point;
+ for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
+ const size_t nr_block_size = min(nc - nr_block_start, nr);
+ int32_t* packed_b = (int32_t*) packed_w;
+ if XNN_LIKELY(b != NULL) {
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset];
+ packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
+ }
+ } else {
+ size_t n = nr_block_size;
+ do {
+ *((int32_t*) packed_w) = 0;
+ packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
+ } while (--n != 0);
+ }
+ packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
+ for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+ const size_t kr_block_size = min(kc - kr_block_start, kr);
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ int32_t ksum = 0;
+ for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+ const int8_t kv = k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+ ksum += (int32_t) kv;
+ *((int8_t*) packed_w) = kv;
+ packed_w = (void*) ((uintptr_t) packed_w + sizeof(int8_t));
+ }
+ packed_b[nr_block_offset] -= ksum * izp;
+ packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(int8_t));
+ }
+ packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(int8_t));
+ }
+ }
+}
+
void xnn_pack_f32_conv_goki_w(
size_t g,
size_t nc,
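
Two notes on xnn_pack_qs8_gemm_io_w above. First, each block of nr output channels is laid out as nr int32 bias slots followed by nr * round_up(kc, kr) int8 weights, with zero padding for partial blocks. Second, packed_b[nr_block_offset] -= ksum * izp folds the input-zero-point correction into the bias: since sum_k (x_k - izp) * w_k = sum_k x_k * w_k - izp * sum_k w_k, pre-subtracting izp times each column's weight sum lets the microkernel accumulate raw x * w products. A reference (unpacked) computation the packed form is equivalent to, as a hypothetical sketch:

```c
#include <stddef.h>
#include <stdint.h>

// Reference computation that the packed layout reproduces; hypothetical
// helper for illustration, not part of this change.
static void qs8_fc_accumulate_reference(
    size_t nc, size_t kc,
    const int8_t* input,    // kc input values
    const int8_t* kernel,   // io layout: kernel[k * nc + n]
    const int32_t* bias,    // nc bias values, may be NULL
    int32_t input_zero_point,
    int32_t* acc)           // nc int32 accumulators
{
  for (size_t n = 0; n < nc; n++) {
    int32_t sum = bias == NULL ? 0 : bias[n];
    for (size_t k = 0; k < kc; k++) {
      // The packed form instead stores bias[n] - izp * sum_k(kernel[k*nc+n])
      // up front, so its inner loop skips the per-element subtraction.
      sum += ((int32_t) input[k] - input_zero_point)
           * (int32_t) kernel[k * nc + n];
    }
    acc[n] = sum;
  }
}
```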
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 765bc8c..c252d94 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -59,6 +59,7 @@
xnn_operator_type_divide_nd_f32,
xnn_operator_type_elu_nc_f32,
xnn_operator_type_floor_nc_f32,
xnn_operator_type_fully_connected_nc_f32,
+ xnn_operator_type_fully_connected_nc_qs8,
xnn_operator_type_fully_connected_nc_qu8,
xnn_operator_type_global_average_pooling_nwc_f16,
diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h
index ae3eb6d..ba6c237 100644
--- a/src/xnnpack/pack.h
+++ b/src/xnnpack/pack.h
@@ -147,6 +147,17 @@
void* packed_w,
const struct xnn_qu8_packing_params* params);
+XNN_INTERNAL void xnn_pack_qs8_gemm_io_w(
+ size_t nc,
+ size_t kc,
+ size_t nr,
+ size_t kr,
+ size_t sr,
+ const int8_t* k,
+ const int32_t* b,
+ void* packed_w,
+ const struct xnn_qs8_packing_params* params);
+
typedef void (*xnn_pack_conv_goki_w_function)(
size_t g,
diff --git a/test/fully-connected-nc.cc b/test/fully-connected-nc.cc
index a25a2e9..28a0d69 100644
--- a/test/fully-connected-nc.cc
+++ b/test/fully-connected-nc.cc
@@ -11,6 +11,144 @@
#include "fully-connected-operator-tester.h"
+TEST(FULLY_CONNECTED_NC_QS8, unit_batch) {
+ FullyConnectedOperatorTester()
+ .batch_size(1)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, unit_batch_with_qmin) {
+ FullyConnectedOperatorTester()
+ .batch_size(1)
+ .input_channels(23)
+ .output_channels(19)
+ .qmin(128)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, unit_batch_with_qmax) {
+ FullyConnectedOperatorTester()
+ .batch_size(1)
+ .input_channels(23)
+ .output_channels(19)
+ .qmax(128)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, unit_batch_with_input_stride) {
+ FullyConnectedOperatorTester()
+ .batch_size(1)
+ .input_channels(23)
+ .input_stride(28)
+ .output_channels(19)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, unit_batch_with_output_stride) {
+ FullyConnectedOperatorTester()
+ .batch_size(1)
+ .input_channels(23)
+ .output_channels(19)
+ .output_stride(29)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, unit_batch_transpose_weights) {
+ FullyConnectedOperatorTester()
+ .transpose_weights(true)
+ .batch_size(1)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, unit_batch_without_bias) {
+ FullyConnectedOperatorTester()
+ .has_bias(false)
+ .batch_size(1)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, small_batch) {
+ FullyConnectedOperatorTester()
+ .batch_size(12)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, small_batch_with_qmin) {
+ FullyConnectedOperatorTester()
+ .batch_size(12)
+ .input_channels(23)
+ .output_channels(19)
+ .qmin(128)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, small_batch_with_qmax) {
+ FullyConnectedOperatorTester()
+ .batch_size(12)
+ .input_channels(23)
+ .output_channels(19)
+ .qmax(128)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, small_batch_with_input_stride) {
+ FullyConnectedOperatorTester()
+ .batch_size(12)
+ .input_channels(23)
+ .input_stride(28)
+ .output_channels(19)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, small_batch_with_output_stride) {
+ FullyConnectedOperatorTester()
+ .batch_size(12)
+ .input_channels(23)
+ .output_channels(19)
+ .output_stride(29)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, small_batch_transpose_weights) {
+ FullyConnectedOperatorTester()
+ .transpose_weights(true)
+ .batch_size(12)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestQS8();
+}
+
+TEST(FULLY_CONNECTED_NC_QS8, small_batch_without_bias) {
+ FullyConnectedOperatorTester()
+ .has_bias(false)
+ .batch_size(12)
+ .input_channels(23)
+ .output_channels(19)
+ .iterations(3)
+ .TestQS8();
+}
+
TEST(FULLY_CONNECTED_NC_QU8, unit_batch) {
FullyConnectedOperatorTester()
.batch_size(1)
diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h
index 648de5d..8cd8033 100644
--- a/test/fully-connected-operator-tester.h
+++ b/test/fully-connected-operator-tester.h
@@ -129,6 +129,125 @@
return this->iterations_;
}
+ void TestQS8() const {
+ std::random_device random_device;
+ auto rng = std::mt19937(random_device());
+ auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
+ auto i8rng = std::bind(std::uniform_int_distribution<int32_t>(
+ -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()), rng);
+
+ std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
+ (batch_size() - 1) * input_stride() + input_channels());
+ std::vector<int8_t> kernel(output_channels() * input_channels());
+ std::vector<int32_t> bias(output_channels());
+ std::vector<int8_t> output((batch_size() - 1) * output_stride() + output_channels());
+ std::vector<int32_t> accumulators(batch_size() * output_channels());
+ std::vector<double> output_ref(batch_size() * output_channels());
+
+ const int8_t input_zero_point = 127;
+
+ for (size_t iteration = 0; iteration < iterations(); iteration++) {
+ std::generate(input.begin(), input.end(), std::ref(i8rng));
+ std::generate(kernel.begin(), kernel.end(), std::ref(i8rng));
+ std::generate(bias.begin(), bias.end(), std::ref(i32rng));
+ std::fill(output.begin(), output.end(), INT8_C(0xA5));
+
+ // Compute reference results, without renormalization.
+ if (has_bias()) {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ accumulators[i * output_channels() + oc] = bias[oc];
+ }
+ }
+ } else {
+ std::fill(accumulators.begin(), accumulators.end(), 0);
+ }
+ if (transpose_weights()) {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ accumulators[i * output_channels() + oc] +=
+ (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+ int32_t(kernel[ic * output_channels() + oc]);
+ }
+ }
+ }
+ } else {
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t oc = 0; oc < output_channels(); oc++) {
+ for (size_t ic = 0; ic < input_channels(); ic++) {
+ accumulators[i * output_channels() + oc] +=
+ (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
+ int32_t(kernel[oc * input_channels() + ic]);
+ }
+ }
+ }
+ }
+
+ // Compute renormalization parameters.
+ const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
+ const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
+
+ const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
+ const int8_t output_zero_point = int8_t(std::max(std::min(
+ lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
+ long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
+
+ // Renormalize reference results.
+ std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
+ [this, output_scale, output_zero_point](int32_t x) -> double {
+ return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
+ });
+
+ // Create, setup, run, and destroy Fully Connected operator.
+ ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+ xnn_operator_t fully_connected_op = nullptr;
+
+ const xnn_status status = xnn_create_fully_connected_nc_qs8(
+ input_channels(), output_channels(),
+ input_stride(), output_stride(),
+ input_zero_point, 1.0f /* input scale */,
+ 1.0f /* kernel scale */,
+ kernel.data(), has_bias() ? bias.data() : nullptr,
+ output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
+ transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
+ &fully_connected_op);
+ if (status == xnn_status_unsupported_hardware) {
+ GTEST_SKIP();
+ }
+ ASSERT_EQ(xnn_status_success, status);
+ ASSERT_NE(nullptr, fully_connected_op);
+
+ // Smart pointer to automatically delete fully_connected_op.
+ std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_setup_fully_connected_nc_qs8(
+ fully_connected_op,
+ batch_size(),
+ input.data(), output.data(),
+ nullptr /* thread pool */));
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_run_operator(fully_connected_op, nullptr /* thread pool */));
+
+ // Verify results.
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t c = 0; c < output_channels(); c++) {
+ ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80))
+ << "batch index = " << i << ", channel = " << c;
+ ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80))
+ << "batch index = " << i << ", channel = " << c;
+ ASSERT_NEAR(
+ output_ref[i * output_channels() + c],
+ double(output[i * output_stride() + c]) - double(output_zero_point),
+ 0.9)
+ << "batch index = " << i << ", channel = " << c;
+ }
+ }
+ }
+ }
+
void TestQU8() const {
std::random_device random_device;
auto rng = std::mt19937(random_device());
@@ -203,8 +322,7 @@
ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
xnn_operator_t fully_connected_op = nullptr;
- ASSERT_EQ(xnn_status_success,
- xnn_create_fully_connected_nc_qu8(
+ const xnn_status status = xnn_create_fully_connected_nc_qu8(
input_channels(), output_channels(),
input_stride(), output_stride(),
input_zero_point, 1.0f /* input scale */,
@@ -212,7 +330,12 @@
kernel.data(), has_bias() ? bias.data() : nullptr,
output_zero_point, output_scale, qmin(), qmax(),
transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
- &fully_connected_op));
+ &fully_connected_op);
+ if (status == xnn_status_unsupported_hardware) {
+ GTEST_SKIP();
+ }
+ ASSERT_EQ(xnn_status_success, status);
+ ASSERT_NE(nullptr, fully_connected_op);
// Smart pointer to automatically delete fully_connected_op.
std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
@@ -310,14 +433,18 @@
ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
xnn_operator_t fully_connected_op = nullptr;
- ASSERT_EQ(xnn_status_success,
- xnn_create_fully_connected_nc_f32(
+ const xnn_status status = xnn_create_fully_connected_nc_f32(
input_channels(), output_channels(),
input_stride(), output_stride(),
kernel.data(), has_bias() ? bias.data() : nullptr,
output_min, output_max,
transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
- &fully_connected_op));
+ &fully_connected_op);
+ if (status == xnn_status_unsupported_hardware) {
+ GTEST_SKIP();
+ }
+ ASSERT_EQ(xnn_status_success, status);
+ ASSERT_NE(nullptr, fully_connected_op);
// Smart pointer to automatically delete fully_connected_op.
std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
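
Finally, a quick sanity check of the renormalization parameters TestQS8 derives from the reference accumulators (accumulator range invented):

```c
// Worked example for TestQS8's renormalization:
//   accumulated_min = -1000, accumulated_max = 1500
//   output_scale      = (1500 - (-1000)) / 255.0 = 9.8039...
//   output_zero_point = lrint(-0.5 - 0.5 * (-1000 + 1500) / 9.8039)
//                     = lrint(-0.5 - 25.5) = -26
// The accumulator extremes then land exactly on the int8 extremes:
//   -26 + (-1000) / 9.8039 = -128    and    -26 + 1500 / 9.8039 = 127
// so the reference outputs span the full quantized range before the
// [qmin() - 0x80, qmax() - 0x80] clamp is applied.
```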