// blob: 7d76550947f7ebc3567f1ac81278f5a8c059425d [file] [log] [blame]
#include "./fused_nbit_rowwise_conversion.h"
#include <c10/util/Half.h>
#include <algorithm>
#include <cmath>
#include <cstring>
#include "common.h"
#ifdef USE_FBGEMM
#include "fbgemm/QuantUtils.h"
#endif
namespace caffe2 {
// Quantizes every row of a row-major float matrix to one byte per element and
// appends that row's scale and bias directly after the quantized bytes.
//
// input:         input_rows * input_columns floats, row-major.
// input_rows:    number of rows to convert.
// input_columns: number of floats per input row.
// output:        input_rows * (input_columns + 2 * sizeof(float)) bytes. Each
//                output row holds input_columns quantized bytes followed by
//                two raw floats: scale (= row range / 255) and bias (= row
//                minimum).
//
// Degenerate sizes: a negative input_columns makes the call a no-op; a row
// with zero columns is written with scale 0 and bias 0 instead of reading
// past the (empty) row to look for a minimum/maximum.
void FloatToFused8BitRowwiseQuantized__base(
    const float* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  constexpr float kEpsilon = 1e-8f;
  if (input_columns < 0) {
    // No meaningful row layout exists for a negative width.
    return;
  }
  const auto num_columns = static_cast<std::size_t>(input_columns);
  const std::size_t output_columns = num_columns + 2 * sizeof(float);
  for (std::size_t row = 0; row < input_rows; ++row) {
    const float* input_row = input + row * num_columns;
    std::uint8_t* output_row = output + row * output_columns;

    // std::min_element/std::max_element return the end iterator for an empty
    // range, which must not be dereferenced.
    float minimum_element = 0.0f;
    float maximum_element = 0.0f;
    if (num_columns > 0) {
      minimum_element = *std::min_element(input_row, input_row + num_columns);
      maximum_element = *std::max_element(input_row, input_row + num_columns);
    }
    const float range = maximum_element - minimum_element;

    // The scale/bias slot starts at byte offset num_columns, which is only
    // float-aligned when num_columns is a multiple of sizeof(float); copy the
    // bytes instead of storing through a casted float pointer.
    const float scale_bias[2] = {range / 255.0f, minimum_element};
    std::memcpy(output_row + num_columns, scale_bias, sizeof(scale_bias));

    // kEpsilon keeps the division finite when all elements are equal.
    const float inverse_scale = 255.0f / (range + kEpsilon);
    for (std::size_t col = 0; col < num_columns; ++col) {
      output_row[col] = static_cast<std::uint8_t>(
          std::lrintf((input_row[col] - minimum_element) * inverse_scale));
    }
  }
}
// Expands rows of fused 8-bit quantized data back to floats.
//
// input:         input_rows * input_columns bytes. Each row holds
//                (input_columns - 2 * sizeof(float)) quantized bytes followed
//                by two raw floats: scale, then bias.
// input_rows:    number of rows to convert.
// input_columns: number of bytes per input row, including the scale/bias
//                trailer.
// output:        input_rows * (input_columns - 2 * sizeof(float)) floats,
//                row-major; each element is quantized * scale + bias.
//
// If input_columns leaves no room for any quantized byte (it is not larger
// than the scale/bias trailer), the call is a no-op. Without this guard a
// too-small width yields a negative element count that the unsigned loop
// comparison would turn into an out-of-bounds walk.
void Fused8BitRowwiseQuantizedToFloat__base(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    float* output) {
  constexpr int kScaleBiasBytes = static_cast<int>(2 * sizeof(float));
  if (input_columns <= kScaleBiasBytes) {
    return;
  }
  const auto row_bytes = static_cast<std::size_t>(input_columns);
  const std::size_t output_columns = row_bytes - kScaleBiasBytes;
  for (std::size_t row = 0; row < input_rows; ++row) {
    const std::uint8_t* input_row = input + row * row_bytes;

    // The trailer sits at an arbitrary byte offset, so copy it out rather
    // than reading through a (possibly misaligned) float pointer.
    float scale_bias[2];
    std::memcpy(scale_bias, input_row + output_columns, sizeof(scale_bias));
    const float scale = scale_bias[0];
    const float bias = scale_bias[1];

    float* output_row = output + row * output_columns;
    for (std::size_t col = 0; col < output_columns; ++col) {
      // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
      output_row[col] = input_row[col] * scale + bias;
    }
  }
}
// Public entry point for float -> fused 8-bit rowwise quantization.
// All arguments are forwarded unchanged: to the FBGEMM kernel when
// USE_FBGEMM is defined at build time, otherwise to the portable
// FloatToFused8BitRowwiseQuantized__base implementation.
void FloatToFused8BitRowwiseQuantized(
    const float* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
#ifndef USE_FBGEMM
  // Portable fallback.
  FloatToFused8BitRowwiseQuantized__base(
      input, input_rows, input_columns, output);
#else
  // FBGEMM-provided kernel, instantiated for float input.
  fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
      input, input_rows, input_columns, output);
#endif
}
// Public entry point for fused 8-bit rowwise -> float dequantization.
// All arguments are forwarded unchanged: to the FBGEMM kernel when
// USE_FBGEMM is defined at build time, otherwise to the portable
// Fused8BitRowwiseQuantizedToFloat__base implementation.
void Fused8BitRowwiseQuantizedToFloat(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    float* output) {
#ifndef USE_FBGEMM
  // Portable fallback.
  Fused8BitRowwiseQuantizedToFloat__base(
      input, input_rows, input_columns, output);
#else
  // FBGEMM-provided kernel, instantiated for float output.
  fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<float>(
      input, input_rows, input_columns, output);
#endif
}
void FloatToFusedNBitRowwiseQuantizedSBHalf__base(
int bit_rate,
const float* input,
size_t input_rows,
int input_columns,
std::uint8_t* output) {
int num_elem_per_byte = 8 / bit_rate;
int output_columns =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
(input_columns + num_elem_per_byte - 1) / num_elem_per_byte +
2 * sizeof(at::Half);
for (std::size_t row = 0; row < input_rows; ++row) {
const float* input_row = input + row * input_columns;
std::uint8_t* output_row = output + row * output_columns;
at::Half* output_row_scale_bias = reinterpret_cast<at::Half*>(
output_row +
(input_columns + num_elem_per_byte - 1) / num_elem_per_byte);
float minimum_element =
*std::min_element(input_row, input_row + input_columns);
float maximum_element =
*std::max_element(input_row, input_row + input_columns);
minimum_element = static_cast<at::Half>(minimum_element);
const float range = maximum_element - minimum_element;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
at::Half scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1);
if (scale == 0) {
// Corner case handling when maximum_element == minimum_element
// Any scale would work because X - minimum_element will be 0 for all X
scale = 1.0f;
}
float inverse_scale = 1.0f / scale;
if (std::isinf(inverse_scale)) {
scale = 1.0f;
inverse_scale = 1.0f;
}
output_row_scale_bias[0] = scale;
output_row_scale_bias[1] = minimum_element;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
for (std::size_t col = 0; col < input_columns; ++col) {
float X = input_row[col];
std::uint8_t quantized = std::max(
0,
std::min<int>(
std::lrintf((X - minimum_element) * inverse_scale),
(1 << bit_rate) - 1));
if (col % num_elem_per_byte == 0) {
output_row[col / num_elem_per_byte] = quantized;
} else {
output_row[col / num_elem_per_byte] |=
(quantized << ((col % num_elem_per_byte) * bit_rate));
}
}
}
}
void FusedNBitRowwiseQuantizedSBHalfToFloat__base(
int bit_rate,
const std::uint8_t* input,
size_t input_rows,
int input_columns,
float* output) {
int num_elem_per_byte = 8 / bit_rate;
int output_columns =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
(input_columns - 2 * sizeof(at::Half)) * num_elem_per_byte;
for (std::size_t row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const at::Half* input_row_scale_bias = reinterpret_cast<const at::Half*>(
input_row +
(output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
float scale = input_row_scale_bias[0];
float bias = input_row_scale_bias[1];
float* output_row = output + row * output_columns;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
for (std::size_t col = 0; col < output_columns; ++col) {
std::uint8_t quantized = input_row[col / num_elem_per_byte];
quantized >>= (col % num_elem_per_byte) * bit_rate;
quantized &= (1 << bit_rate) - 1;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
output_row[col] = scale * quantized + bias;
}
}
}
// Public entry point for float -> fused N-bit rowwise quantization with
// half-precision scale/bias. All arguments are forwarded unchanged: to the
// FBGEMM kernel when USE_FBGEMM is defined at build time, otherwise to the
// portable FloatToFusedNBitRowwiseQuantizedSBHalf__base implementation.
void FloatToFusedNBitRowwiseQuantizedSBHalf(
    int bit_rate,
    const float* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
#ifndef USE_FBGEMM
  // Portable fallback.
  FloatToFusedNBitRowwiseQuantizedSBHalf__base(
      bit_rate, input, input_rows, input_columns, output);
#else
  // FBGEMM-provided kernel, instantiated for float input.
  fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
      bit_rate, input, input_rows, input_columns, output);
#endif
}
// Public entry point for fused N-bit rowwise -> float dequantization with
// half-precision scale/bias. All arguments are forwarded unchanged: to the
// FBGEMM kernel when USE_FBGEMM is defined at build time, otherwise to the
// portable FusedNBitRowwiseQuantizedSBHalfToFloat__base implementation.
void FusedNBitRowwiseQuantizedSBHalfToFloat(
    int bit_rate,
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    float* output) {
#ifndef USE_FBGEMM
  // Portable fallback.
  FusedNBitRowwiseQuantizedSBHalfToFloat__base(
      bit_rate, input, input_rows, input_columns, output);
#else
  // FBGEMM-provided kernel, instantiated for float output.
  fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<float>(
      bit_rate, input, input_rows, input_columns, output);
#endif
}
} // namespace caffe2