| """Generates the specialized gemv functions.""" |
| |
| import mul_1x8_Mx8_neon |
| import mul_Nx8_Mx8_neon |
| import qnt_Nx8_neon |
| import zip_Nx8_neon |
| |
| _QUANTIZED_8BIT = 'quantized_8bit' |
| _FULL_32BIT = 'full_32bit' |
| _FULL_FLOAT = 'full_float' |
| |
| |
| class Error(Exception): |
| """Module level error.""" |
| |
| |
| class ConfigurationError(Error): |
| """Runtime configuration error.""" |
| |
| |
| def GenerateCommonTempsCountersAndConsts(emitter): |
| """Generates common gemv boilerplate variables.""" |
| emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 8') |
| emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8') |
| emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 4') |
| emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size', |
| '(padded_k + 16) * 4') |
| emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs') |
| emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch') |
| emitter.EmitDeclare('std::int32_t*', 'zipped_lhs_offsets', |
| 'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k)') |
| emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_1', |
| 'scratch + padded_k + 16') |
| emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_2', |
| 'zipped_rhs_1 + zipped_chunk_size') |
| emitter.EmitNewline() |
| |
| |
| def GenerateQuantized8BitTempsCountersAndConsts(emitter): |
| """Generates all the boilerplate variables for the q8 gemm function.""" |
| GenerateCommonTempsCountersAndConsts(emitter) |
| emitter.EmitDeclare('const std::int32_t', 'const_offset', |
| 'lhs_offset * rhs_offset * k + result_offset') |
| emitter.EmitDeclare('const std::int32_t', 'rounding_offset', |
| '(1 << (shift - 1))') |
| emitter.EmitDeclare('std::int32_t*', 'temp_result', |
| 'reinterpret_cast<std::int32_t*>(' |
| 'zipped_rhs_2 + zipped_chunk_size)') |
| emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result') |
| emitter.EmitNewline() |
| |
| |
| def GenerateFullTempsCountersAndConsts(emitter, result_type): |
| """Generates all the boilerplate variables for the int32 and float gemms.""" |
| GenerateCommonTempsCountersAndConsts(emitter) |
| emitter.EmitDeclare('const std::int32_t', 'const_offset', |
| 'lhs_offset * rhs_offset * k') |
| emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result') |
| emitter.EmitNewline() |
| |
| |
| def GenerateZipVector(emitter, aligned, leftovers): |
| emitter.EmitCall( |
| zip_Nx8_neon.BuildName(1, leftovers, aligned), |
| ['lhs', 'k', 'k', 'zipped_lhs', 'rhs_offset', 0]) |
| |
| |
| def GetMul2Params(result_type): |
| params = ['zipped_lhs', 'zipped_rhs_1', 'zipped_rhs_2', 'padded_k', |
| 'mul_result_chunk'] |
| if result_type is 'float': |
| params.append('result_scale') |
| return params |
| |
| |
| def GetMulParams(result_type): |
| params = ['zipped_lhs', 'zipped_rhs_1', 'padded_k', 'mul_result_chunk', 0] |
| if result_type is 'float': |
| params.append('result_scale') |
| return params |
| |
| |
| def GenerateMulCols(emitter, result_type, lhs_add, rhs_add, aligned, cols, |
| leftovers): |
| """Emits code responsible for multiplication of one horizontal lhs strip.""" |
| emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)') |
| emitter.EmitCall( |
| zip_Nx8_neon.BuildName(4, leftovers, aligned), |
| ['rhs_chunk', 'k', 'k', 'zipped_rhs_1', 'lhs_offset', 'const_offset']) |
| emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size') |
| |
| emitter.EmitCall( |
| zip_Nx8_neon.BuildName(4, leftovers, aligned), |
| ['rhs_chunk', 'k', 'k', 'zipped_rhs_2', 'lhs_offset', 'const_offset']) |
| emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size') |
| |
| emitter.EmitCall( |
| mul_1x8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, 8), |
| GetMul2Params(result_type)) |
| |
| emitter.EmitAssignIncrement('mul_result_chunk', 8) |
| emitter.EmitCloseBracket() |
| |
| if cols > 4: |
| emitter.EmitCall( |
| zip_Nx8_neon.BuildName(4, leftovers, aligned), |
| ['rhs_chunk', 'k', 'k', 'zipped_rhs_1', 'lhs_offset', 'const_offset']) |
| emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size') |
| |
| emitter.EmitCall( |
| zip_Nx8_neon.BuildName(cols - 4, leftovers, aligned), |
| ['rhs_chunk', 'k', 'k', 'zipped_rhs_2', 'lhs_offset', 'const_offset']) |
| |
| emitter.EmitCall( |
| mul_1x8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, cols), |
| GetMul2Params(result_type)) |
| elif cols > 0: |
| emitter.EmitCall( |
| zip_Nx8_neon.BuildName(cols, leftovers, aligned), |
| ['rhs_chunk', 'k', 'k', 'zipped_rhs_1', 'lhs_offset', 'const_offset']) |
| |
| emitter.EmitCall( |
| mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, 1, cols), |
| GetMulParams(result_type)) |
| |
| |
| def GenerateQuantized8BitMul(emitter, aligned, cols, leftovers): |
| """Emits code for all lhs strips & leftover rows. Quantize after mul code.""" |
| GenerateMulCols(emitter, 'int32', False, True, aligned, cols, leftovers) |
| emitter.EmitCall( |
| qnt_Nx8_neon.BuildName(1, cols, aligned), |
| ['temp_result', 'n', 0, 'zipped_lhs_offsets', 'result', 0, |
| 'multiplicative_offset', 'rounding_offset', '-shift']) |
| |
| |
| def GenerateFullMul(emitter, result_type, aligned, cols, leftovers): |
| GenerateMulCols(emitter, result_type, True, True, aligned, cols, leftovers) |
| |
| |
| def BuildName(output_type, aligned, cols, leftover): |
| name = BuildMainGemvName(output_type) + '_%d_%d' % (cols, leftover) |
| if aligned: |
| name += '_aligned' |
| return name |
| |
| |
| def GetCommonGemvParameters(): |
| return [['std::uint8_t*', 'scratch'], ['const std::uint8_t*', 'lhs'], |
| ['const std::uint8_t*', 'rhs'], ['std::int32_t', 'n'], |
| ['std::int32_t', 'k'], ['std::int32_t', 'lhs_offset'], |
| ['std::int32_t', 'rhs_offset']] |
| |
| |
| def GetGemvParameters(output_type): |
| """Prepares a (type, parameter) array for the gemm functions.""" |
| params = GetCommonGemvParameters() |
| if output_type is _QUANTIZED_8BIT: |
| params += [['std::int32_t', 'result_offset'], |
| ['std::int32_t', 'multiplicative_offset'], |
| ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']] |
| elif output_type is _FULL_32BIT: |
| params += [['std::int32_t*', 'result']] |
| elif output_type is _FULL_FLOAT: |
| params += [['float', 'result_scale'], ['float*', 'result']] |
| else: |
| raise ConfigurationError('Unsupported output type: %s' % output_type) |
| return params |
| |
| |
| def GenerateGemv(emitter, output_type, aligned, cols, leftovers): |
| """Build one gemm function for given col, and depth leftovers.""" |
| emitter.EmitFunctionBeginA( |
| BuildName(output_type, aligned, cols, leftovers), |
| GetGemvParameters(output_type), 'void') |
| |
| emitter.EmitAssert('n %% 8 == %d' % cols) |
| emitter.EmitAssert('k %% 8 == %d' % leftovers) |
| |
| if output_type is _QUANTIZED_8BIT: |
| GenerateQuantized8BitTempsCountersAndConsts(emitter) |
| GenerateZipVector(emitter, aligned, leftovers) |
| GenerateQuantized8BitMul(emitter, aligned, cols, leftovers) |
| elif output_type is _FULL_32BIT: |
| GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*') |
| GenerateZipVector(emitter, aligned, leftovers) |
| GenerateFullMul(emitter, 'int32', aligned, cols, leftovers) |
| elif output_type is _FULL_FLOAT: |
| GenerateFullTempsCountersAndConsts(emitter, 'float*') |
| GenerateZipVector(emitter, aligned, leftovers) |
| GenerateFullMul(emitter, 'float', aligned, cols, leftovers) |
| else: |
| raise ConfigurationError('Unknown output type: %s' % output_type) |
| |
| emitter.EmitFunctionEnd() |
| |
| |
| def GenerateGemvCall(emitter, output_type, aligned, m_mod, leftovers): |
| emitter.EmitCall( |
| emitter.Scope('internal', |
| BuildName(output_type, aligned, m_mod, leftovers)), |
| [p for (unused_t, p) in GetGemvParameters(output_type)]) |
| |
| |
| def GenerateGemvSwitch2(emitter, output_type, aligned, n_mod): |
| """Second level of main switch, choose optimized version on depth leftover.""" |
| emitter.EmitSwitch('k % 8') |
| |
| for leftovers in range(0, 8): |
| emitter.EmitCase(leftovers) |
| emitter.PushIndent() |
| GenerateGemvCall(emitter, output_type, aligned, n_mod, leftovers) |
| emitter.EmitBreak() |
| emitter.PopIndent() |
| |
| emitter.EmitSwitchEnd() |
| |
| |
| def GenerateGemvSwitch1(emitter, output_type, aligned): |
| """First level of main switch, choose optimized version on cols leftover.""" |
| emitter.EmitSwitch('n % 8') |
| |
| for n_mod in range(0, 8): |
| emitter.EmitCase(n_mod) |
| emitter.PushIndent() |
| GenerateGemvSwitch2(emitter, output_type, aligned, n_mod) |
| emitter.EmitBreak() |
| emitter.PopIndent() |
| |
| emitter.EmitSwitchEnd() |
| |
| |
| def BuildMainGemvName(output_type): |
| if output_type is _QUANTIZED_8BIT: |
| return 'gemv_q8' |
| elif output_type is _FULL_32BIT: |
| return 'gemv_i32' |
| elif output_type is _FULL_FLOAT: |
| return 'gemv_f' |
| else: |
| raise ConfigurationError('Unsupported output type: %s' % output_type) |
| |
| |
| def GenerateMainGemvFunction(emitter, output_type): |
| """Emit high level gemv function that switches between optimized versions.""" |
| emitter.EmitFunctionBeginA( |
| BuildMainGemvName(output_type), GetGemvParameters(output_type), 'void') |
| |
| emitter.EmitDeclare('const bool', 'lhs_aligned', |
| '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)') |
| emitter.EmitDeclare('const bool', 'rhs_aligned', |
| '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)') |
| emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)') |
| |
| if output_type is _QUANTIZED_8BIT: |
| emitter.EmitDeclare('const bool', 'result_aligned', |
| '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)') |
| emitter.EmitDeclare('const bool', 'aligned', |
| 'lhs_aligned && rhs_aligned && result_aligned ' |
| '&& k_aligned') |
| else: |
| emitter.EmitDeclare('const bool', 'aligned', |
| 'lhs_aligned && rhs_aligned && k_aligned') |
| |
| emitter.EmitIf('aligned') |
| GenerateGemvSwitch1(emitter, output_type, True) |
| emitter.EmitElse() |
| GenerateGemvSwitch1(emitter, output_type, False) |
| emitter.EmitEndif() |
| emitter.EmitFunctionEnd() |
| |
| |
| def GenerateInternalFunctions(emitter): |
| """Generate all the functions hidden in the internal namespace.""" |
| for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]: |
| for aligned in [True, False]: |
| for cols in range(0, 8): |
| for leftover in range(0, 8): |
| GenerateGemv(emitter, output_type, aligned, cols, leftover) |
| emitter.EmitNewline() |
| |
| |
| def GeneratePublicFunctions(emitter): |
| for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]: |
| GenerateMainGemvFunction(emitter, output_type) |
| emitter.EmitNewline() |