| """Mul primitive used by the GEMM function. |
| |
The Mul primitive takes 1-3 zipped rows and 1-3 zipped columns (plus a single
1 row by 4 columns variant) and performs matrix multiplication on those,
resulting in a small 1x1 to 3x3 (or 1x4) block of results.
| """ |
| |
| import neon_emitter |
| |
| |
class Error(Exception):
  """Base class for the exceptions raised by this module."""
| |
| |
class ConfigurationError(Error):
  """Raised when a requested configuration is not supported."""
| |
| |
class MulLanes(object):
  """An input address together with the list of lanes attached to it."""

  def __init__(self, input_address):
    # Lanes are appended later through AddLane().
    self.lanes = []
    self.input_address = input_address

  def AddLane(self, lane):
    """Append `lane` at the end of the lane list."""
    self.lanes += [lane]

  def FreeRegisters(self, registers):
    """Release every lane through registers.FreeRegister().

    Each released slot is overwritten with None, so the list keeps its
    length and len(self.lanes) still reports the lane count afterwards.
    """
    for slot, lane in enumerate(self.lanes):
      registers.FreeRegister(lane)
      self.lanes[slot] = None
| |
| |
def GenerateMulLanes(registers, lane_count, address):
  """Return a MulLanes for `address` with `lane_count` lanes.

  Each lane is obtained from a separate registers.DoubleRegister() call.
  """
  group = MulLanes(address)
  for _ in range(lane_count):
    group.AddLane(registers.DoubleRegister())
  return group
| |
| |
def Generate3MulLanes(quad_register, registers, address):
  """Return a three-lane MulLanes for `address` that reuses `quad_register`.

  Lanes 0 and 1 are registers.Low() and registers.High() of `quad_register`;
  lane 2 comes from registers.DoubleRegister().
  """
  group = MulLanes(address)
  three_lanes = (
      registers.Low(quad_register),
      registers.High(quad_register),
      registers.DoubleRegister(),
  )
  for lane in three_lanes:
    group.AddLane(lane)
  return group
| |
| |
def GenerateAndClearAggregators(emitter, registers, aggregator_count):
  """Allocate `aggregator_count` aggregators and emit the code clearing them.

  Every aggregator comes from registers.QuadRegister(). The first three are
  set with an 'i32' move of the immediate constant 0; each later one is set
  with a move from the aggregator allocated three positions earlier.

  Returns:
    The list of allocated aggregators, in allocation order.
  """
  emitter.EmitComment('Clear aggregators.')
  allocated = []
  for position in range(aggregator_count):
    current = registers.QuadRegister()
    allocated.append(current)
    source = (emitter.ImmediateConstant(0)
              if position < 3 else allocated[position - 3])
    emitter.EmitVMov('i32', current, source)
  emitter.EmitNewline()
  return allocated
| |
| |
def GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count):
  """Emit the generic inner loop for N rows x M cols multiplication.

  The emitted loop body: numerical label 1, `count` decremented by 8, both
  operands loaded (address post-incremented) and prefetched, one 'u8'
  EmitVMull per (row, col) pair, one 'u16' EmitVPadal per product into the
  aggregator with the same index, then a branch back to label 1.

  Args:
    emitter: code emitter receiving the instructions.
    registers: allocator providing (and taking back) the product registers.
    left_lanes: MulLanes-like object for the rows (lanes, input_address).
    right_lanes: MulLanes-like object for the columns.
    aggregators: accumulators indexed by row * cols + col.
    count: register holding the loop counter.
  """
  emitter.EmitComment('General NxM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  operands = (left_lanes, right_lanes)
  for operand in operands:
    emitter.EmitVLoadA(
        '1.8', operand.lanes,
        emitter.DereferenceIncrement(operand.input_address, 64))
  for operand in operands:
    emitter.EmitPldOffset(operand.input_address,
                          emitter.ImmediateConstant(64))

  row_registers = left_lanes.lanes
  col_registers = right_lanes.lanes

  # One scratch register per cell of the rows x cols result block.
  products = [
      registers.QuadRegister()
      for _ in range(len(row_registers) * len(col_registers))
  ]

  # Row-major walk: the running index equals row * cols + col.
  cell = 0
  for row_register in row_registers:
    for col_register in col_registers:
      emitter.EmitVMull('u8', products[cell], col_register, row_register)
      cell += 1

  for cell, product in enumerate(products):
    emitter.EmitVPadal('u16', aggregators[cell], product)

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  for product in products:
    registers.FreeRegister(product)
| |
| |
def Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register):
  """Emit the inner loop for 3 rows x 3 cols multiplication (register trick).

  Only four scratch registers are allocated for the nine products. The
  products are computed in two passes (four, then five), each pass being
  accumulated into the aggregators before the scratch registers are reused.
  The ninth product is written to `backup_register`.

  Args:
    emitter: code emitter receiving the instructions.
    registers: allocator providing (and taking back) the scratch registers.
    left_lanes: MulLanes-like object with three row lanes.
    right_lanes: MulLanes-like object with three column lanes.
    aggregators: nine accumulators indexed by row * 3 + col.
    count: register holding the loop counter.
    backup_register: destination of the last (row 2, col 2) product.
      NOTE(review): the caller in this file passes the quad register whose
      halves are left lanes 0 and 1, which are no longer read at that point.
  """
  emitter.EmitComment('3x3 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  operands = (left_lanes, right_lanes)
  for operand in operands:
    emitter.EmitVLoadA(
        '1.8', operand.lanes,
        emitter.DereferenceIncrement(operand.input_address, 64))
  for operand in operands:
    emitter.EmitPldOffset(operand.input_address,
                          emitter.ImmediateConstant(64))

  scratch = [registers.QuadRegister() for _ in range(4)]

  # Each entry is (destination register, row, col); the matching aggregator
  # index is row * 3 + col.
  passes = (
      ((scratch[0], 0, 0), (scratch[1], 0, 1), (scratch[2], 0, 2),
       (scratch[3], 1, 0)),
      ((scratch[0], 1, 1), (scratch[1], 1, 2), (scratch[2], 2, 0),
       (scratch[3], 2, 1), (backup_register, 2, 2)),
  )
  for schedule in passes:
    for destination, row, col in schedule:
      emitter.EmitVMull('u8', destination, left_lanes.lanes[row],
                        right_lanes.lanes[col])
    for destination, row, col in schedule:
      emitter.EmitVPadal('u16', aggregators[row * 3 + col], destination)

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  for register in scratch:
    registers.FreeRegister(register)
| |
| |
def ReadParams(emitter, registers, input_address, elements, min_reg):
  """Allocate a register for `elements` values and emit a load into it.

  One or two elements use registers.DoubleRegister(min_reg * 2); three or
  four use registers.QuadRegister(min_reg). The load is a '1.32' EmitVLoad
  from emitter.Dereference(input_address, 64).

  Returns:
    The register that receives the loaded values.

  Raises:
    ConfigurationError: `elements` is not 1, 2, 3 or 4.
  """
  if elements in (1, 2):
    target = registers.DoubleRegister(min_reg * 2)
  elif elements in (3, 4):
    target = registers.QuadRegister(min_reg)
  else:
    raise ConfigurationError('Unsupported elements no: %d' % elements)

  emitter.EmitVLoad('1.32', target, emitter.Dereference(input_address, 64))
  return target
| |
| |
def Duplicate(emitter, registers, rows, cols, min_register, values):
  """Populate a grid of registers duplicating provided values.

  One register is allocated per row: a double register when `cols` is 1 or
  2, a quad register when it is 3 or 4. Row i then gets a '32' EmitVDup of
  lane i of `values`. For 1 or 2 rows the lanes of `values` are addressed
  directly; for 3 or 4 rows they are addressed through registers.Low()
  (rows 0, 1) and registers.High() (rows 2, 3). Any other row count emits
  nothing.

  Returns:
    The list of allocated registers, one per row.

  Raises:
    ConfigurationError: `cols` is not 1, 2, 3 or 4.
  """
  if cols in (1, 2):
    allocate = registers.DoubleRegister
  elif cols in (3, 4):
    allocate = registers.QuadRegister
  else:
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)

  copies = [allocate(min_register) for _ in range(rows)]

  if rows in (1, 2):
    for lane in range(rows):
      emitter.EmitVDup('32', copies[lane], emitter.Lane(values, lane))
  elif rows in (3, 4):
    for lane in range(rows):
      # Lanes 0-1 live in the low half of `values`, lanes 2-3 in the high.
      half = registers.Low if lane < 2 else registers.High
      emitter.EmitVDup('32', copies[lane],
                       emitter.Lane(half(values), lane % 2))

  return copies
| |
| |
def DuplicateGeneralRegister(emitter, registers, cols, general_register,
                             min_register):
  """Allocate a register and emit a '32' EmitVDup of `general_register`.

  A double register is allocated when `cols` is 1 or 2, a quad register when
  it is 3 or 4.

  Returns:
    The newly allocated register.

  Raises:
    ConfigurationError: `cols` is not 1, 2, 3 or 4.
  """
  if cols not in (1, 2, 3, 4):
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)

  if cols in (1, 2):
    target = registers.DoubleRegister(min_register)
  else:
    target = registers.QuadRegister(min_register)

  emitter.EmitVDup('32', target, general_register)
  return target
| |
| |
def ReduceAggregator(emitter, registers, aggregators, row, cols):
  """Emit the code gathering one row of aggregators into a single register.

  The aggregators of a row occupy the consecutive indices row * cols up to
  row * cols + cols - 1. Their low halves are combined with 'u32' EmitVPadd
  instructions into the first aggregator of the row.

  Args:
    emitter: code emitter receiving the instructions.
    registers: provides Low()/High() views of a quad register.
    aggregators: flat list of aggregators, `cols` entries per row.
    row: index of the row to reduce.
    cols: number of columns (1, 2, 3 or 4).

  Returns:
    The register holding the reduced row: the low half of the row's first
    aggregator for 1 or 2 columns, the whole first aggregator for 3 or 4.

  Raises:
    ConfigurationError: `cols` is not 1, 2, 3 or 4.
  """
  # Row stride equals the column count. The 4-column case used to index
  # with row * 3, which is only right for row 0.
  base = row * cols

  if cols == 1:
    register = registers.Low(aggregators[base])
    emitter.EmitVPadd('u32', register, register, register)
    return register
  elif cols == 2:
    register = registers.Low(aggregators[base])
    emitter.EmitVPadd('u32', register, register,
                      registers.Low(aggregators[base + 1]))
    return register
  elif cols == 3:
    register = aggregators[base]
    emitter.EmitVPadd('u32', registers.Low(register), registers.Low(register),
                      registers.Low(aggregators[base + 1]))
    # Only one value left for the high half: the third column is paired
    # with itself.
    emitter.EmitVPadd('u32', registers.High(register),
                      registers.Low(aggregators[base + 2]),
                      registers.Low(aggregators[base + 2]))
    return register
  elif cols == 4:
    register = aggregators[base]
    emitter.EmitVPadd('u32', registers.Low(register), registers.Low(register),
                      registers.Low(aggregators[base + 1]))
    emitter.EmitVPadd('u32', registers.High(register),
                      registers.Low(aggregators[base + 2]),
                      registers.Low(aggregators[base + 3]))
    return register
  else:
    raise ConfigurationError('Unsupported columns no: %d' % cols)
| |
| |
def StoreAggregator(emitter, registers, aggregator, cols, result_address,
                    result_stride):
  """Emit the '1.32' store(s) writing one reduced row of `cols` values.

  1 column stores lane 0 of `aggregator`, 2 columns store `aggregator`
  itself, 4 columns store its low and high halves together; each of these
  is a single offset store using `result_stride`. 3 columns take two
  stores: the low half with an incrementing dereference, then lane 0 of the
  high half with `result_stride`, followed by a newline.

  Raises:
    ConfigurationError: `cols` is not 1, 2, 3 or 4.
  """
  if cols == 1:
    source = emitter.Lane(aggregator, 0)
    destination = emitter.Dereference(result_address, None)
    emitter.EmitVStoreOffset('1.32', source, destination, result_stride)
    return

  if cols == 2:
    destination = emitter.Dereference(result_address, None)
    emitter.EmitVStoreOffset('1.32', aggregator, destination, result_stride)
    return

  if cols == 3:
    low_half = registers.Low(aggregator)
    emitter.EmitVStore('1.32', low_half,
                       emitter.DereferenceIncrement(result_address, None))
    last_value = emitter.Lane(registers.High(aggregator), 0)
    destination = emitter.Dereference(result_address, None)
    emitter.EmitVStoreOffset('1.32', last_value, destination, result_stride)
    emitter.EmitNewline()
    return

  if cols == 4:
    halves = [registers.Low(aggregator), registers.High(aggregator)]
    destination = emitter.Dereference(result_address, None)
    emitter.EmitVStoreOffsetA('1.32', halves, destination, result_stride)
    return

  raise ConfigurationError('Unsupported columns no: %d' % cols)
| |
| |
def GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                  lhs_add, rhs_add, left_lanes, right_lanes,
                                  results, results_stride):
  """Emit code that reduces 4 lane aggregators to 1 value, and stores them.

  Steps emitted, in order: optional load of the lhs/rhs offsets found at the
  current input addresses, optional duplication of the float result scale,
  horizontal reduction of every aggregator, gathering of each row into one
  register, optional offset additions, optional conversion to float and
  scaling, and finally one store per row.

  Args:
    emitter: code emitter receiving the instructions.
    registers: register allocator.
    aggregators: flat list of aggregators, cols entries per row.
    result_type: 'float' enables the convert-and-scale step; any other value
      stores the integer sums.
    lhs_add: when true, per-row offsets are read from the lhs input address.
    rhs_add: when true, per-column offsets are read from the rhs address.
    left_lanes: MulLanes of the rows; only its length and address are used.
    right_lanes: MulLanes of the columns; only its length and address are
      used.
    results: register holding the result address.
    results_stride: register holding the stride between result rows.
  """
  rows = len(left_lanes.lanes)
  cols = len(right_lanes.lanes)
  # Compare by value: identity ('is') on a string literal only works when
  # both strings happen to be interned.
  is_float = result_type == 'float'

  if lhs_add:
    left_offset = ReadParams(emitter, registers, left_lanes.input_address, rows,
                             4)
    left_offsets = Duplicate(emitter, registers, rows, cols, 4, left_offset)
  else:
    left_offsets = None

  if rhs_add:
    right_offset = ReadParams(emitter, registers, right_lanes.input_address,
                              cols, 4)
  else:
    right_offset = None

  if is_float:
    result_scale = DuplicateGeneralRegister(
        emitter, registers, cols, registers.MapParameter('result_scale'), 4)
  else:
    result_scale = None

  if cols == 3:
    # A 3-column row is stored in two operations and the first one already
    # advances the result address, so the stride is reduced by 8.
    emitter.EmitNewline()
    emitter.EmitComment('Change stride because storing in two ops.')
    emitter.EmitSub(results_stride, results_stride,
                    emitter.ImmediateConstant(8))

  emitter.EmitNewline()
  emitter.EmitComment('Horizontal reduce aggregators.')
  for aggregator in aggregators:
    emitter.EmitVPadd('u32', registers.Low(aggregator),
                      registers.Low(aggregator), registers.High(aggregator))

  emitter.EmitNewline()
  emitter.EmitComment('Reduce rows.')
  row_temps = [
      ReduceAggregator(emitter, registers, aggregators, row, cols)
      for row in range(rows)
  ]

  if lhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add lhs offsets to aggregated rows.')
    for row_temp, row_offset in zip(row_temps, left_offsets):
      emitter.EmitVAdd('s32', row_temp, row_temp, row_offset)

  if rhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add rhs offset to aggregated rows.')
    for row_temp in row_temps:
      emitter.EmitVAdd('s32', row_temp, row_temp, right_offset)

  if is_float:
    emitter.EmitNewline()
    emitter.EmitComment('Convert to float. Multiply by result scale.')
    for row_temp in row_temps:
      emitter.EmitVCvt('f32', 's32', row_temp, row_temp)
    for row_temp in row_temps:
      emitter.EmitVMul('f32', row_temp, row_temp, result_scale)

  emitter.EmitNewline()
  emitter.EmitComment('Store reduced rows.')
  for row_temp in row_temps:
    StoreAggregator(emitter, registers, row_temp, cols, results, results_stride)
| |
| |
def BuildName(result_type, lhs_add, rhs_add, left, right):
  """Return the generated function name for the given configuration.

  The name is 'mul_<left>x8_<right>x8_<result_type>', followed by '_lhsadd'
  and/or '_rhsadd' when the matching flag is set.
  """
  pieces = ['mul_%dx8_%dx8_%s' % (left, right, result_type)]
  if lhs_add:
    pieces.append('_lhsadd')
  if rhs_add:
    pieces.append('_rhsadd')
  return ''.join(pieces)
| |
| |
def CppResultType(result_type):
  """Return the C++ pointer type of the result buffer for `result_type`.

  Args:
    result_type: 'int32' or 'float'.

  Returns:
    'std::int32_t*' for 'int32', 'float*' for 'float'.

  Raises:
    ConfigurationError: `result_type` is neither 'int32' nor 'float'.
  """
  # Compare by value: identity ('is') on a string literal only works when
  # both strings happen to be interned.
  if result_type == 'int32':
    return 'std::int32_t*'
  elif result_type == 'float':
    return 'float*'
  else:
    raise ConfigurationError('Unsupported result type: %s' % result_type)
| |
| |
def GetParameters(result_type):
  """Return the [type, name] parameter list of a generated function.

  Args:
    result_type: 'int32' or 'float'.

  Returns:
    A list of [C++ type, parameter name] pairs: lhs, rhs, count, result and
    result_stride, plus a trailing result_scale for 'float' results.

  Raises:
    ConfigurationError: `result_type` is not supported by CppResultType().
  """
  params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs'],
            ['std::int32_t', 'count'], [CppResultType(result_type), 'result'],
            ['std::int32_t', 'result_stride']]
  # Compare by value: identity ('is') on a string literal only works when
  # both strings happen to be interned.
  if result_type == 'float':
    params.append(['float', 'result_scale'])
  return params
| |
| |
def GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes_count,
                      right_lanes_count):
  """Emit the multiply code for given rows and cols counts.

  Emits a complete inline function: prologue and asserts, aggregator
  clearing, the multiply/accumulate inner loop, then the reduce-and-store
  epilogue.

  Args:
    emitter: code emitter receiving the generated function.
    result_type: 'int32' or 'float'.
    lhs_add: when true, lhs offsets are added to the results.
    rhs_add: when true, rhs offsets are added to the results.
    left_lanes_count: number of rows (1 to 4).
    right_lanes_count: number of columns (1 to 4).

  Raises:
    ConfigurationError: a lane count is outside 1-4, or the combination
      yields more than 9 results. The only inner loop available for 9 or
      more results is the dedicated 3x3 one, which reads exactly three lanes
      per side and nine aggregators.
  """
  if left_lanes_count < 1 or left_lanes_count > 4:
    raise ConfigurationError('Left_lanes should be: 1, 2, 3 or 4.')
  if right_lanes_count < 1 or right_lanes_count > 4:
    raise ConfigurationError('Right_lanes should be: 1, 2, 3 or 4.')

  size = left_lanes_count * right_lanes_count
  use_3x3_loop = left_lanes_count == 3 and right_lanes_count == 3

  # Reject 3x4, 4x3 and 4x4 before emitting anything. They used to fall
  # into the 3x3 branch and produce code that ignored the extra lanes.
  if size > 9:
    raise ConfigurationError('Unsupported lanes combination: %dx%d' %
                             (left_lanes_count, right_lanes_count))

  emitter.EmitFunctionBeginA(
      BuildName(result_type, lhs_add, rhs_add, left_lanes_count,
                right_lanes_count), GetParameters(result_type), 'inline void')

  emitter.EmitAssert('count % 8 == 0')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  lhs = registers.MapParameter('lhs')
  rhs = registers.MapParameter('rhs')

  emitter.EmitPld(lhs)
  emitter.EmitPld(rhs)

  aggregators = GenerateAndClearAggregators(emitter, registers, size)

  if use_3x3_loop:
    # The quad register backing left lanes 0 and 1 doubles as the
    # destination of the last product inside the loop.
    backup_register = registers.QuadRegister()
    left_lanes = Generate3MulLanes(backup_register, registers, lhs)
    right_lanes = GenerateMulLanes(registers, right_lanes_count, rhs)

    Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register)
  else:
    left_lanes = GenerateMulLanes(registers, left_lanes_count, lhs)
    right_lanes = GenerateMulLanes(registers, right_lanes_count, rhs)

    GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count)

  # The lane lists keep their length after being freed; the epilogue only
  # needs the lane counts and the input addresses.
  left_lanes.FreeRegisters(registers)
  right_lanes.FreeRegisters(registers)

  GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                lhs_add, rhs_add, left_lanes, right_lanes,
                                registers.MapParameter('result'),
                                registers.MapParameter('result_stride'))

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()
| |
| |
def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
  """Emit every multiply function variant for one result configuration.

  Variants are emitted for all 1-3 rows by 1-3 columns combinations, rows
  varying slowest, followed by the 1 row by 4 columns variant. A newline is
  emitted after each function.
  """
  shapes = [(rows, cols) for rows in (1, 2, 3) for cols in (1, 2, 3)]
  shapes.append((1, 4))

  for rows, cols in shapes:
    GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, rows, cols)
    emitter.EmitNewline()
| |
| |
if __name__ == '__main__':
  # Script entry point: emit the int32 variants with both lhs and rhs
  # offset additions enabled.
  GenerateFunctions(
      emitter=neon_emitter.NeonEmitter(),
      result_type='int32',
      lhs_add=True,
      rhs_add=True)