"""Qnt (quantization) primitive used by the GEMM function.

Generates the NEON qnt kernels and the C++ dispatchers that select them.
"""
| |
| import neon_emitter |
| |
| |
class Error(Exception):
  """Base class for every exception raised by this module."""


class ConfigurationError(Error):
  """Raised when asked to generate code for an unsupported configuration."""
| |
| |
class QntLane(object):
  """Bundle of registers used to quantize one row (lane) of the input."""

  def __init__(self, source, output, offset, load_1, load_2):
    # Registers holding this row's input and output pointers.
    self.source, self.output = source, output
    # Register holding this row's offset (duplicated across the register).
    self.offset = offset
    # Two registers receiving this row's unquantized 32-bit values.
    self.load_1, self.load_2 = load_1, load_2
| |
| |
def BuildName(lanes, leftovers, aligned):
  """Returns the kernel name, e.g. 'qnt_1x8', 'qnt_2x8_3', 'qnt_3x8_5_aligned'."""
  parts = ['qnt_%dx8' % lanes]
  if leftovers:
    parts.append('_%d' % leftovers)
  if aligned:
    parts.append('_aligned')
  return ''.join(parts)
| |
| |
def LoadAndDuplicateOffsets(emitter, registers, lanes, offsets):
  """Emits loads that broadcast one 32-bit offset per row into a quad register.

  For each of the `lanes` rows a fresh quad register is allocated and both of
  its halves are filled (all lanes) from `offsets`, which is addressed through
  DereferenceIncrement so consecutive rows read consecutive offsets.

  Args:
    emitter: code emitter receiving the load instructions.
    registers: register allocator.
    lanes: number of rows, 1 to 3.
    offsets: register holding the offsets pointer.

  Returns:
    The allocated quad registers, one per row, in row order.

  Raises:
    ConfigurationError: if `lanes` is not 1, 2 or 3.
  """
  if lanes not in (1, 2, 3):
    raise ConfigurationError('Unsupported number of lanes: %d' % lanes)

  broadcast = []
  for _ in range(lanes):
    quad = registers.QuadRegister()
    low_all = emitter.AllLanes(registers.Low(quad))
    high_all = emitter.AllLanes(registers.High(quad))
    emitter.EmitVLoadA('1.32', [low_all, high_all],
                       emitter.DereferenceIncrement(offsets, 32))
    broadcast.append(quad)
  return broadcast
| |
| |
def GenerateQntLanes(emitter, registers, qnt_lanes, source, stride, destination,
                     destination_stride, offsets):
  """Prepare lanes for reading unquantized multiplication results.

  The first row reuses the `source`/`destination` registers directly. Each
  following row gets two new general registers, initialized by emitted adds to
  the previous row's pointers plus `stride` / `destination_stride`.

  Returns:
    A list of QntLane objects, one per row, each with its own broadcast offset
    register and two freshly allocated load registers.
  """
  offset_registers = LoadAndDuplicateOffsets(emitter, registers, qnt_lanes,
                                             offsets)

  result = []
  row_in, row_out = source, destination
  for index, row_offset in enumerate(offset_registers):
    if not index:
      result.append(QntLane(row_in, row_out, row_offset,
                            registers.QuadRegister(),   # load 1
                            registers.QuadRegister()))  # load 2
      continue

    next_in = registers.GeneralRegister()
    next_out = registers.GeneralRegister()
    result.append(QntLane(next_in, next_out, row_offset,
                          registers.QuadRegister(),   # load 1
                          registers.QuadRegister()))  # load 2
    # Point this row one stride past the previous row.
    emitter.EmitAdd(next_in, row_in, stride)
    emitter.EmitAdd(next_out, row_out, destination_stride)
    row_in, row_out = next_in, next_out
  return result
| |
| |
def DuplicateRegister(emitter, registers, value):
  """Allocates a quad register and broadcasts `value` into every 32-bit lane."""
  quad = registers.QuadRegister()
  emitter.EmitVDup('32', quad, value)
  return quad
| |
| |
def GenerateQuantize(emitter, registers, lanes, lane_temps,
                     multiplicative_offset, rounding_offset, shift):
  """Inner loop for quantization: add offsets, multiply, round, shift, narrow.

  Each entry of `lanes` is indexed as [value, offset, narrow_target]. The value
  register gets the offset added, is multiplied by `multiplicative_offset`, has
  `rounding_offset` added, is shifted by `shift` (VSHL), and is then
  saturating-narrowed from signed 32-bit into `narrow_target`. Each operation
  is emitted for every entry before the next operation starts.

  Finally every quad register in `lane_temps` is saturating-narrowed from signed
  16-bit to unsigned 8-bit into its own low half.
  """
  passes = (
      lambda entry: emitter.EmitVAdd('i32', entry[0], entry[0], entry[1]),
      lambda entry: emitter.EmitVMul('i32', entry[0], entry[0],
                                     multiplicative_offset),
      lambda entry: emitter.EmitVAdd('i32', entry[0], entry[0],
                                     rounding_offset),
      lambda entry: emitter.EmitVShl('s32', entry[0], entry[0], shift),
      lambda entry: emitter.EmitVQmovn('s32', entry[2], entry[0]),
  )
  for emit_pass in passes:
    for entry in lanes:
      emit_pass(entry)

  for temp in lane_temps:
    emitter.EmitVQmovun('s16', registers.Low(temp), temp)
| |
| |
def GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                              rounding_offset, shift, alignment):
  """Load unquantized data from lanes, quantize, store final result.

  Each row loads eight 32-bit values into its load_1/load_2 registers (source
  pointer post-incremented), gets a prefetch hint at its advanced source
  pointer, is quantized into a per-row temporary quad register, and has the low
  half of that temporary stored as bytes to its output.

  Args:
    emitter: code emitter receiving the instructions.
    registers: register allocator; the temporaries are returned to it.
    lanes: QntLane objects, one per row.
    multiplicative_offset, rounding_offset, shift: registers holding the
      broadcast quantization parameters.
    alignment: alignment hint for the output stores, or None for no hint.
  """
  temps = [registers.QuadRegister() for _ in lanes]

  for row in lanes:
    d_registers = [registers.Low(row.load_1), registers.High(row.load_1),
                   registers.Low(row.load_2), registers.High(row.load_2)]
    emitter.EmitVLoadA('1.32', d_registers,
                       emitter.DereferenceIncrement(row.source, 64))

  for row in lanes:
    emitter.EmitPld(row.source)

  # Values 0-3 of a row narrow into the temp's low half, values 4-7 into its
  # high half.
  setup = [entry
           for temp, row in zip(temps, lanes)
           for entry in ([row.load_1, row.offset, registers.Low(temp)],
                         [row.load_2, row.offset, registers.High(temp)])]

  GenerateQuantize(emitter, registers, setup, temps,
                   multiplicative_offset, rounding_offset, shift)

  for temp, row in zip(temps, lanes):
    emitter.EmitVStore('1.8', registers.Low(temp),
                       emitter.DereferenceIncrement(row.output, alignment))

  for temp in temps:
    registers.FreeRegister(temp)
| |
| |
def GenerateLoadLeftovers(emitter, registers, leftovers, lanes):
  """Handle loading of the leftover values when the row size is not a multiple of 8.

  Values 1-4 of each row go into load_1 and values 5-7 into load_2. Pairs of
  values are loaded as whole d registers with a 64-bit alignment qualifier. A
  final odd value is loaded into lane 0 of the next d register. The source
  pointer is post-incremented only when another load for the same row follows.

  Args:
    emitter: code emitter receiving the load instructions.
    registers: register allocator, used to address halves of quad registers.
    leftovers: number of trailing 32-bit values per row, 1 to 7.
    lanes: QntLane objects, one per row.

  Raises:
    ConfigurationError: if `leftovers` is not in 1..7.
  """
  if leftovers == 1:
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(
          registers.Low(lane.load_1), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 2:
    for lane in lanes:
      emitter.EmitVLoad('1.32', registers.Low(lane.load_1),
                        emitter.Dereference(lane.source, 64))
  elif leftovers == 3:
    for lane in lanes:
      emitter.EmitVLoad('1.32', registers.Low(lane.load_1),
                        emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(
          registers.High(lane.load_1), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 4:
    for lane in lanes:
      emitter.EmitVLoadA('1.32', [registers.Low(lane.load_1),
                                  registers.High(lane.load_1)],
                         emitter.Dereference(lane.source, 64))
  elif leftovers == 5:
    for lane in lanes:
      emitter.EmitVLoadA('1.32', [registers.Low(lane.load_1),
                                  registers.High(lane.load_1)],
                         emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(
          registers.Low(lane.load_2), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 6:
    for lane in lanes:
      emitter.EmitVLoadA('1.32', [registers.Low(lane.load_1),
                                  registers.High(lane.load_1),
                                  registers.Low(lane.load_2)],
                         emitter.Dereference(lane.source, 64))
  elif leftovers == 7:
    for lane in lanes:
      emitter.EmitVLoadA('1.32', [registers.Low(lane.load_1),
                                  registers.High(lane.load_1),
                                  registers.Low(lane.load_2)],
                         emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(
          registers.High(lane.load_2), 0),
                        emitter.Dereference(lane.source, None))
  else:
    raise ConfigurationError('Unsupported leftover count: %d' % leftovers)
| |
| |
def GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes):
  """Emits stores for the 1-7 trailing 8-bit results of every row.

  The results are in the low half of each row's temporary quad register. They
  are written with single-lane stores of 32, 16 and/or 8 bits. Every store
  except the last post-increments the row's output pointer; when `leftovers`
  is 7 the last one does too.

  Raises:
    ConfigurationError: if `leftovers` is not in 1..7.
  """
  targets = [(registers.Low(temp), lane.output)
             for temp, lane in zip(lane_temps, lanes)]

  # leftovers -> steps of (element size, lane index, post-increment). Each step
  # is emitted for every row before the next step begins.
  store_plans = {
      1: (('1.8', 0, False),),
      2: (('1.16', 0, False),),
      3: (('1.16', 0, True), ('1.8', 2, False)),
      4: (('1.32', 0, False),),
      5: (('1.32', 0, True), ('1.8', 4, False)),
      6: (('1.32', 0, True), ('1.16', 2, False)),
      7: (('1.32', 0, True), ('1.16', 2, True), ('1.8', 6, True)),
  }
  if leftovers not in store_plans:
    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)

  for size, lane_index, advance in store_plans[leftovers]:
    for narrow, output in targets:
      emitter.EmitVStore(
          size, emitter.Lane(narrow, lane_index),
          (emitter.DereferenceIncrement if advance else emitter.Dereference)(
              output, None))
| |
| |
def GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift):
  """Handle leftovers if the row size is not a multiple of 8.

  Loads the 1-7 trailing values of every row, quantizes them, and stores the
  resulting bytes. load_2 is only quantized when more than 4 values remain.

  Args:
    emitter: code emitter receiving the instructions.
    registers: register allocator; the temporaries are returned to it.
    leftovers: number of trailing values per row, 1 to 7.
    lanes: QntLane objects, one per row.
    multiplicative_offset, rounding_offset, shift: registers holding the
      broadcast quantization parameters.
  """
  lane_temps = [registers.QuadRegister() for _ in lanes]

  GenerateLoadLeftovers(emitter, registers, leftovers, lanes)

  quantize_setup = []
  for lane_temp, lane in zip(lane_temps, lanes):
    quantize_setup.append([lane.load_1, lane.offset, registers.Low(lane_temp)])
    if leftovers > 4:
      # Values 5-7 of each row were loaded into load_2.
      quantize_setup.append(
          [lane.load_2, lane.offset, registers.High(lane_temp)])

  GenerateQuantize(emitter, registers, quantize_setup, lane_temps,
                   multiplicative_offset, rounding_offset, shift)

  GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes)

  # Release the temporaries, mirroring GenerateLoadQuantizeStore, so the
  # allocator's bookkeeping stays balanced.
  for lane_temp in lane_temps:
    registers.FreeRegister(lane_temp)
| |
| |
def GenerateQntNx8(emitter, qnt_lanes, leftovers, aligned):
  """Emits optimized quantization code for given lanes and row size.

  Emits one C++ function, named by BuildName, wrapping an inline-asm block. The
  block quantizes `qnt_lanes` rows of int32 values into uint8 values, 8 values
  per loop iteration, followed by a tail that handles `leftovers` values.

  Args:
    emitter: code emitter receiving the generated function.
    qnt_lanes: number of rows processed together, 1 to 3.
    leftovers: value of count % 8 the function is specialized for, 0 to 7.
    aligned: if True, destination stores carry a 64-bit alignment hint and the
      generated code asserts that the destination (and its stride) is aligned.

  Raises:
    ConfigurationError: if `leftovers` or `qnt_lanes` is out of range.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if qnt_lanes < 1 or qnt_lanes > 3:
    raise ConfigurationError('Qnt_lanes should be 1, 2 or 3.')

  name = BuildName(qnt_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(
      name,
      [['const std::int32_t*', 'source'], ['std::int32_t', 'count'],
       ['std::int32_t', 'stride'], ['const std::int32_t*', 'offsets'],
       ['std::uint8_t*', 'destination'], ['std::int32_t', 'destination_stride'],
       ['std::int32_t', 'multiplicative_offset'],
       ['std::int32_t', 'rounding_offset'], ['std::int32_t', 'shift']], 'void')
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
  if qnt_lanes > 1:
    # Every row's source loads use a 64-bit alignment qualifier, so rows after
    # the first (source + k * stride) are only aligned if stride is too.
    emitter.EmitAssert('stride % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
    if qnt_lanes > 1:
      emitter.EmitAssert('destination_stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  multiplicative_offset = DuplicateRegister(
      emitter, registers, registers.MapParameter('multiplicative_offset'))
  rounding_offset = DuplicateRegister(emitter, registers,
                                      registers.MapParameter('rounding_offset'))
  shift = DuplicateRegister(emitter, registers, registers.MapParameter('shift'))

  lanes = GenerateQntLanes(
      emitter, registers, qnt_lanes, registers.MapParameter('source'),
      registers.MapParameter('stride'), registers.MapParameter('destination'),
      registers.MapParameter('destination_stride'),
      registers.MapParameter('offsets'))

  if leftovers:
    # Remove the tail from the count; if nothing else remains, go straight to
    # the leftover handler at label 2.
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
    emitter.EmitBeqFront(2)

  # Main loop (label 1): 8 values per row per iteration until count hits 0.
  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                            rounding_offset, shift, 64 if aligned else None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    emitter.EmitNumericalLabel(2)
    GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift)

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()
| |
| |
def BuildMultiQuantizeName(aligned, rows):
  """Returns the dispatcher name, e.g. 'multi_qnt_2x8' or 'multi_qnt_3x8_aligned'."""
  suffix = '_aligned' if aligned else ''
  return 'multi_qnt_%dx8%s' % (rows, suffix)
| |
| |
def GenerateMultiQuantize(emitter, aligned, rows):
  """Emit main quantization code that switches between optimized versions.

  The emitted function switches on `count % 8` and forwards all of its
  arguments to the qnt kernel specialized for that remainder.
  """
  name = BuildMultiQuantizeName(aligned, rows)
  parameters = [['const std::int32_t*', 'source'], ['std::int32_t', 'count'],
                ['std::int32_t', 'stride'], ['const std::int32_t*', 'offsets'],
                ['std::uint8_t*', 'destination'],
                ['std::int32_t', 'destination_stride'],
                ['std::int32_t', 'multiplicative_offset'],
                ['std::int32_t', 'rounding_offset'], ['std::int32_t', 'shift']]
  # The kernels take exactly the dispatcher's parameters, in the same order.
  argument_names = [parameter[1] for parameter in parameters]

  emitter.EmitFunctionBeginA(name, parameters, 'void')
  emitter.EmitSwitch('count % 8')

  for remainder in range(8):
    emitter.EmitCase(remainder)
    emitter.PushIndent()
    emitter.EmitCall(BuildName(rows, remainder, aligned), list(argument_names))
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()
  emitter.EmitFunctionEnd()
| |
| |
def GenerateFunctions(neon, cc):
  """Emits every qnt kernel into `neon` and every dispatcher into `cc`.

  Kernels cover aligned then unaligned variants, 1-3 rows and 0-7 leftovers.
  Dispatchers cover aligned then unaligned variants for 1-3 rows. Each
  generated function is followed by a blank line.
  """
  kernel_configs = [(aligned, lanes, leftovers)
                    for aligned in (True, False)
                    for lanes in (1, 2, 3)
                    for leftovers in range(8)]
  for aligned, lanes, leftovers in kernel_configs:
    GenerateQntNx8(neon, lanes, leftovers, aligned)
    neon.EmitNewline()

  dispatcher_configs = [(aligned, rows)
                        for aligned in (True, False)
                        for rows in (1, 2, 3)]
  for aligned, rows in dispatcher_configs:
    GenerateMultiQuantize(cc, aligned, rows)
    cc.EmitNewline()