Clamp NC operator for S8 data type

- New API functions: xnn_create_clamp_nc_s8 and xnn_setup_clamp_nc_s8
- Unit tests
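
The new pair follows the same create / setup / run lifecycle as the existing
U8 clamp operator. A minimal usage sketch (the clamp_rows_s8 wrapper and the
[-100, 100] range are illustrative only; status checks omitted for brevity):

  #include <stddef.h>
  #include <stdint.h>
  #include <xnnpack.h>

  // Illustrative helper: clamp each row of signed 8-bit data to [-100, 100].
  void clamp_rows_s8(const int8_t* input, int8_t* output,
                     size_t batch_size, size_t channels) {
    xnn_initialize(NULL /* allocator */);

    xnn_operator_t clamp_op = NULL;
    xnn_create_clamp_nc_s8(
      channels, /*input_stride=*/channels, /*output_stride=*/channels,
      /*output_min=*/-100, /*output_max=*/100, /*flags=*/0, &clamp_op);

    // Bind the buffers and batch size, then run single-threaded.
    xnn_setup_clamp_nc_s8(clamp_op, batch_size, input, output, NULL /* thread pool */);
    xnn_run_operator(clamp_op, NULL /* thread pool */);
    xnn_delete_operator(clamp_op);
  }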

PiperOrigin-RevId: 391216240
diff --git a/include/xnnpack.h b/include/xnnpack.h
index 24dd917..c9a76f0 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -2451,6 +2451,22 @@
 
 #ifndef XNN_NO_S8_OPERATORS
 
+enum xnn_status xnn_create_clamp_nc_s8(
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  int8_t output_min,
+  int8_t output_max,
+  uint32_t flags,
+  xnn_operator_t* clamp_op_out);
+
+enum xnn_status xnn_setup_clamp_nc_s8(
+  xnn_operator_t clamp_op,
+  size_t batch_size,
+  const int8_t* input,
+  int8_t* output,
+  pthreadpool_t threadpool);
+
 enum xnn_status xnn_create_max_pooling2d_nhwc_s8(
   uint32_t input_padding_top,
   uint32_t input_padding_right,
diff --git a/src/init.c b/src/init.c
index 003d89b..213797f 100644
--- a/src/init.c
+++ b/src/init.c
@@ -270,6 +270,11 @@
     #ifndef XNN_NO_S8_OPERATORS
       init_flags |= XNN_INIT_FLAG_S8;
 
+      xnn_params.s8.clamp = (struct vunary_parameters) {
+        .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__neon_x64,
+        .init.s8_minmax = xnn_init_s8_minmax_neon_params,
+        .element_tile = 64,
+      };
       xnn_params.s8.maxpool = (struct maxpool_parameters) {
         .ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__neon_c16,
         .init.s8 = xnn_init_s8_minmax_neon_params,
@@ -740,6 +745,11 @@
     #ifndef XNN_NO_S8_OPERATORS
       init_flags |= XNN_INIT_FLAG_S8;
 
+      xnn_params.s8.clamp = (struct vunary_parameters) {
+        .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__scalar_x4,
+        .init.s8_minmax = xnn_init_s8_minmax_scalar_params,
+        .element_tile = 4,
+      };
       xnn_params.s8.maxpool = (struct maxpool_parameters) {
         .ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__scalar_c1,
         .init.s8 = xnn_init_s8_minmax_scalar_params,
@@ -1540,6 +1550,11 @@
   #ifndef XNN_NO_S8_OPERATORS
     init_flags |= XNN_INIT_FLAG_S8;
 
+    xnn_params.s8.clamp = (struct vunary_parameters) {
+      .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__neon_x64,
+      .init.s8_minmax = xnn_init_s8_minmax_neon_params,
+      .element_tile = 64,
+    };
     xnn_params.s8.maxpool = (struct maxpool_parameters) {
       .ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__neon_c16,
       .init.s8 = xnn_init_s8_minmax_neon_params,
@@ -2639,6 +2654,11 @@
     init_flags |= XNN_INIT_FLAG_S8;
 
     if (cpuinfo_has_x86_sse4_1()) {
+      xnn_params.s8.clamp = (struct vunary_parameters) {
+        .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__sse41_x64,
+        .init.s8_minmax = xnn_init_s8_minmax_sse4_params,
+        .element_tile = 64,
+      };
       xnn_params.s8.maxpool = (struct maxpool_parameters) {
         .ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__sse41_c16,
         .init.s8 = xnn_init_s8_minmax_sse4_params,
@@ -2646,6 +2666,11 @@
         .qr = 8,
       };
     } else {
+      xnn_params.s8.clamp = (struct vunary_parameters) {
+        .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__sse2_x64,
+        .init.s8_minmax = xnn_init_s8_minmax_sse2_params,
+        .element_tile = 64,
+      };
       xnn_params.s8.maxpool = (struct maxpool_parameters) {
         .ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__sse2_c16,
         .init.s8 = xnn_init_s8_minmax_sse2_params,
@@ -3332,6 +3357,11 @@
   #ifndef XNN_NO_S8_OPERATORS
     init_flags |= XNN_INIT_FLAG_S8;
 
+    xnn_params.s8.clamp = (struct vunary_parameters) {
+      .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__wasmsimd_x64,
+      .init.s8_minmax = xnn_init_s8_minmax_wasmsimd_params,
+      .element_tile = 64,
+    };
     xnn_params.s8.maxpool = (struct maxpool_parameters) {
       .ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__wasmsimd_c16,
       .init.s8 = xnn_init_s8_minmax_wasmsimd_params,
@@ -3345,9 +3375,9 @@
     init_flags |= XNN_INIT_FLAG_U8;
 
     xnn_params.u8.clamp = (struct vunary_parameters) {
-      .ukernel = (xnn_univector_ukernel_function) xnn_u8_vclamp_ukernel__scalar_x4,
-      .init.u8_minmax = xnn_init_u8_minmax_scalar_params,
-      .element_tile = 4,
+      .ukernel = (xnn_univector_ukernel_function) xnn_u8_vclamp_ukernel__wasmsimd_x64,
+      .init.u8_minmax = xnn_init_u8_minmax_wasmsimd_params,
+      .element_tile = 64,
     };
     xnn_params.u8.maxpool = (struct maxpool_parameters) {
       .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_minmax_ukernel_9p8x__wasmsimd_c16,
@@ -3974,6 +4004,11 @@
   #ifndef XNN_NO_S8_OPERATORS
     init_flags |= XNN_INIT_FLAG_S8;
 
+    xnn_params.s8.clamp = (struct vunary_parameters) {
+      .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__scalar_x4,
+      .init.s8_minmax = xnn_init_s8_minmax_scalar_params,
+      .element_tile = 4,
+    };
     xnn_params.s8.maxpool = (struct maxpool_parameters) {
       .ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__scalar_c1,
       .init.s8 = xnn_init_s8_minmax_scalar_params,
diff --git a/src/operator-strings.c b/src/operator-strings.c
index fcab61d..30e8fe9 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -44,6 +44,8 @@
       return "Channel Shuffle (NC, X32)";
     case xnn_operator_type_clamp_nc_f32:
       return "Clamp (NC, F32)";
+    case xnn_operator_type_clamp_nc_s8:
+      return "Clamp (NC, S8)";
     case xnn_operator_type_clamp_nc_u8:
       return "Clamp (NC, U8)";
     case xnn_operator_type_constant_pad_nd_x8:
diff --git a/src/operators/unary-elementwise-nc.c b/src/operators/unary-elementwise-nc.c
index bfbd5d8..f76fe58 100644
--- a/src/operators/unary-elementwise-nc.c
+++ b/src/operators/unary-elementwise-nc.c
@@ -148,6 +148,34 @@
   return xnn_status_success;
 }
 
+enum xnn_status xnn_create_clamp_nc_s8(
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    int8_t output_min,
+    int8_t output_max,
+    uint32_t flags,
+    xnn_operator_t* clamp_op_out)
+{
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
+      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_s8), output_min, output_max);
+    return xnn_status_invalid_parameter;
+  }
+
+  union xnn_s8_minmax_params params;
+  if (xnn_params.s8.clamp.init.s8_minmax != NULL) {
+    xnn_params.s8.clamp.init.s8_minmax(&params, output_min, output_max);
+  }
+  return create_unary_elementwise_nc(
+    channels, input_stride, output_stride, flags,
+    &params, sizeof(params),
+    xnn_operator_type_clamp_nc_s8,
+    xnn_params.s8.clamp.ukernel,
+    clamp_op_out);
+}
+
 enum xnn_status xnn_create_clamp_nc_u8(
     size_t channels,
     size_t input_stride,
@@ -549,6 +577,28 @@
     &ceiling_op->params.f32_rnd, sizeof(ceiling_op->params.f32_rnd));
 }
 
+enum xnn_status xnn_setup_clamp_nc_s8(
+    xnn_operator_t clamp_op,
+    size_t batch_size,
+    const int8_t* input,
+    int8_t* output,
+    pthreadpool_t threadpool)
+{
+  if (clamp_op->type != xnn_operator_type_clamp_nc_s8) {
+    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
+      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_s8),
+      xnn_operator_type_to_string(clamp_op->type));
+    return xnn_status_invalid_parameter;
+  }
+  clamp_op->state = xnn_run_state_invalid;
+
+  return setup_unary_elementwise_nc(
+    clamp_op,
+    batch_size, input, output,
+    0 /* log2(sizeof(int8_t)) */,
+    &clamp_op->params.s8_minmax, sizeof(clamp_op->params.s8_minmax));
+}
+
 enum xnn_status xnn_setup_clamp_nc_u8(
     xnn_operator_t clamp_op,
     size_t batch_size,
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index e579e68..60738e2 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -44,6 +44,7 @@
   xnn_operator_type_channel_shuffle_nc_x8,
   xnn_operator_type_channel_shuffle_nc_x32,
   xnn_operator_type_clamp_nc_f32,
+  xnn_operator_type_clamp_nc_s8,
   xnn_operator_type_clamp_nc_u8,
   xnn_operator_type_ceiling_nc_f32,
   xnn_operator_type_constant_pad_nd_x8,
diff --git a/test/clamp-nc.cc b/test/clamp-nc.cc
index 63a7bda..4ad035c 100644
--- a/test/clamp-nc.cc
+++ b/test/clamp-nc.cc
@@ -11,6 +11,88 @@
 #include "clamp-operator-tester.h"
 
 
+TEST(CLAMP_NC_S8, unit_batch) {
+  for (size_t channels = 1; channels < 100; channels++) {
+    ClampOperatorTester()
+      .batch_size(1)
+      .channels(channels)
+      .iterations(3)
+      .TestS8();
+  }
+}
+
+TEST(CLAMP_NC_S8, unit_batch_with_qmin) {
+  for (size_t channels = 1; channels < 100; channels += 15) {
+    for (uint8_t qmin = 1; qmin < 255; qmin++) {
+      ClampOperatorTester()
+        .batch_size(1)
+        .channels(channels)
+        .qmin(qmin)
+        .qmax(255)
+        .iterations(3)
+        .TestS8();
+    }
+  }
+}
+
+TEST(CLAMP_NC_S8, unit_batch_with_qmax) {
+  for (size_t channels = 1; channels < 100; channels += 15) {
+    for (uint8_t qmax = 1; qmax < 255; qmax++) {
+      ClampOperatorTester()
+        .batch_size(1)
+        .channels(channels)
+        .qmin(0)
+        .qmax(qmax)
+        .iterations(3)
+        .TestS8();
+    }
+  }
+}
+
+TEST(CLAMP_NC_S8, small_batch) {
+  for (size_t channels = 1; channels < 100; channels++) {
+    ClampOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .iterations(3)
+      .TestS8();
+  }
+}
+
+TEST(CLAMP_NC_S8, small_batch_with_input_stride) {
+  for (size_t channels = 1; channels < 100; channels += 15) {
+    ClampOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)
+      .iterations(3)
+      .TestS8();
+  }
+}
+
+TEST(CLAMP_NC_S8, small_batch_with_output_stride) {
+  for (size_t channels = 1; channels < 100; channels += 15) {
+    ClampOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .output_stride(117)
+      .iterations(3)
+      .TestS8();
+  }
+}
+
+TEST(CLAMP_NC_S8, small_batch_with_input_and_output_stride) {
+  for (size_t channels = 1; channels < 100; channels += 15) {
+    ClampOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)
+      .output_stride(117)
+      .iterations(3)
+      .TestS8();
+  }
+}
+
 TEST(CLAMP_NC_U8, unit_batch) {
   for (size_t channels = 1; channels < 100; channels++) {
     ClampOperatorTester()
diff --git a/test/clamp-operator-tester.h b/test/clamp-operator-tester.h
index d067bd9..8623e00 100644
--- a/test/clamp-operator-tester.h
+++ b/test/clamp-operator-tester.h
@@ -110,6 +110,69 @@
     return this->iterations_;
   }
 
+  void TestS8() const {
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    auto i8rng = std::bind(
+      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
+      std::ref(rng));
+
+    std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
+      (batch_size() - 1) * input_stride() + channels());
+    std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels());
+    std::vector<int8_t> output_ref(batch_size() * channels());
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(input.begin(), input.end(), std::ref(i8rng));
+      std::fill(output.begin(), output.end(), INT8_C(0xA5));
+
+      // Compute reference results.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          const int8_t x = input[i * input_stride() + c];
+          const int8_t y = std::min(std::max(x, int8_t(qmin() - 0x80)), int8_t(qmax() - 0x80));
+          output_ref[i * channels() + c] = y;
+        }
+      }
+
+      // Create, setup, run, and destroy Clamp operator.
+      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+      xnn_operator_t clamp_op = nullptr;
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_create_clamp_nc_s8(
+          channels(), input_stride(), output_stride(),
+          int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
+          0, &clamp_op));
+      ASSERT_NE(nullptr, clamp_op);
+
+      // Smart pointer to automatically delete clamp_op.
+      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_clamp_op(clamp_op, xnn_delete_operator);
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_setup_clamp_nc_s8(
+          clamp_op,
+          batch_size(),
+          input.data(), output.data(),
+          nullptr /* thread pool */));
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_run_operator(clamp_op, nullptr /* thread pool */));
+
+      // Verify results.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80))
+            << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
+          ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80))
+            << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
+          ASSERT_EQ(int32_t(output_ref[i * channels() + c]), int32_t(output[i * output_stride() + c]))
+            << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels()
+            << ", qmin = " << int32_t(qmin() - 0x80) << ", qmax = " << int32_t(qmax() - 0x80);
+        }
+      }
+    }
+  }
+
   void TestU8() const {
     std::random_device random_device;
     auto rng = std::mt19937(random_device());