| # Copyright 2016 The Gemmlowp Authors. All rights reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """.""" |
| |
| import common |
| |
| |
def _AlignForLanes(lanes_count):
  """Return the alignment hint used when storing a block of packed lanes.

  The value is passed as the alignment argument of EmitVStoreAE; it is
  presumably expressed in bits, like NEON `:align` qualifiers (TODO confirm
  against the emitter).

  Args:
    lanes_count: number of 8-byte lanes stored together.

  Returns:
    256 for 4 or 8 lanes, 128 for 2 or 6 lanes, 64 otherwise.
  """
  # Compare by value: `is` on int literals depends on CPython's small-int
  # cache, fails for int subclasses, and warns on Python 3.8+.
  if lanes_count in (4, 8):
    return 256
  if lanes_count in (2, 6):
    return 128
  return 64
| |
| |
def _AlignForSums(lanes_count):
  """Return the alignment hint used when storing the reduced lane sums.

  The value is passed to emitter.Dereference for the final sums store. It is
  presumably in bits (TODO confirm against the emitter).

  Args:
    lanes_count: number of per-lane sums being stored.

  Returns:
    256 for 8 lanes, 128 for 2, 4 or 6 lanes, 64 otherwise.
  """
  # Value comparison instead of `lanes_count is 8` (identity on an int).
  if lanes_count == 8:
    return 256
  if lanes_count in (2, 4, 6):
    return 128
  return 64
| |
| |
def _GenerateInputs(emitter, registers, lanes_count, input_address, stride):
  """Emit the address computation for each input lane.

  Lane 0 reads directly from `input_address`. Every later lane gets a newly
  allocated general register that holds the previous lane's address plus
  `stride`.

  Args:
    emitter: assembly emitter that receives the EmitAdd instructions.
    registers: register allocator supplying GeneralRegister().
    lanes_count: number of lanes to produce addresses for.
    input_address: register holding the first lane's address.
    stride: register/operand added between consecutive lanes.

  Returns:
    A list of lane address registers in lane order. Its first element is
    `input_address` itself.
  """
  if lanes_count < 1:
    return []

  addresses = [input_address]
  for _ in range(1, lanes_count):
    next_address = registers.GeneralRegister()
    emitter.EmitAdd(next_address, addresses[-1], stride)
    addresses.append(next_address)
  return addresses
| |
| |
def _GenerateClear(emitter, clear_type, block):
  """Emit a move of immediate zero into every register of `block`.

  Args:
    emitter: assembly emitter that receives one EmitVMov per register.
    clear_type: element type string passed to VMov (callers in this file use
      'i8' and 'i16').
    block: registers to clear, processed in order.
  """
  for register in block:
    zero = emitter.ImmediateConstant(0)
    emitter.EmitVMov(clear_type, register, zero)
| |
| |
def _GenerateLoadAggregateStore(emitter, registers, lanes_count, elements_count,
                                aggregators, inputs, output):
  """Emit one load / aggregate / store step over all row lanes.

  Each lane's `elements_count` bytes are loaded into its own double register
  and widened-and-added into that lane's aggregator. All lane registers are
  then stored through `output` as 8 * lanes_count bytes.

  Args:
    emitter: assembly emitter.
    registers: register allocator; the temporary lane registers are freed
      before returning.
    lanes_count: number of lanes (rows).
    elements_count: bytes loaded per lane; 8 for a full step, fewer for the
      leftovers tail.
    aggregators: per-lane sum registers, updated in place.
    inputs: per-lane input address registers.
    output: output address register.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store: %dx%d.' % (lanes_count,
                                                         elements_count))

  block = [registers.DoubleRegister() for _ in range(lanes_count)]

  # For a partial load, zero the lane registers first. Bytes the load does not
  # write (presumably the tail) then contribute 0 to the sums and the packed
  # output. Compare by value; `is not 8` was an identity test on an int.
  if elements_count != 8:
    _GenerateClear(emitter, 'i8', block)

  for (row, input_address) in zip(block, inputs):
    emitter.EmitVLoadE(8, elements_count, row, input_address, None)

  for (aggregator, row) in zip(aggregators, block):
    emitter.EmitVAddw('u8', aggregator, aggregator, row)

  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                       _AlignForLanes(lanes_count))

  registers.FreeRegisters(block)
| |
| |
def _LoadMemoryParameter(emitter, registers, name, source):
  """Load a memory-mapped asm parameter into a new general register.

  Args:
    emitter: assembly emitter that receives the EmitLdr.
    registers: register allocator; also maps `name`/`source` to a memory
      operand via MapMemoryParameter.
    name: asm parameter name.
    source: C++ expression bound to the parameter.

  Returns:
    The newly allocated general register that holds the loaded value.
  """
  destination = registers.GeneralRegister()
  memory_operand = registers.MapMemoryParameter(name, source)
  emitter.EmitLdr(destination, memory_operand)
  return destination
| |
| |
def _GenerateAggregatorReductionLowRegisters(emitter, registers,
                                             aggregators, output_address):
  """Emit the aggregator reduction, loading the sum offsets from memory.

  The multiplicative and additive sum offsets are each loaded with Ldr into
  a freshly allocated general register, in that order, rather than mapped
  directly as asm parameters. The row-major stream uses this variant for
  more than 6 lanes.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')

  # Keep this order: the multiplicative offset is loaded before the additive one.
  multiplicative_sum_offset = _LoadMemoryParameter(
      emitter, registers, 'multiplicative_sum_offset',
      'params.multiplicative_sum_offset')
  additive_sum_offset = _LoadMemoryParameter(
      emitter, registers, 'additive_sum_offset', 'params.additive_sum_offset')

  _GenerateAggregatorReduction(emitter, registers, aggregators, output_address,
                               multiplicative_sum_offset, additive_sum_offset)
| |
| |
def _GenerateAggregatorReductionHighRegisters(emitter, registers,
                                              aggregators, output_address):
  """Emit the aggregator reduction with the sum offsets mapped as parameters.

  The multiplicative and additive sum offsets are bound directly as asm
  input parameters, in that order, with no explicit load.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')

  multiplicative_sum_offset = registers.MapParameter(
      'multiplicative_sum_offset', 'params.multiplicative_sum_offset')
  additive_sum_offset = registers.MapParameter('additive_sum_offset',
                                               'params.additive_sum_offset')

  _GenerateAggregatorReduction(emitter, registers, aggregators, output_address,
                               multiplicative_sum_offset, additive_sum_offset)
| |
| |
def _GenerateAggregatorReduction(emitter, registers, aggregators,
                                 output_address, multiplicative_sum_offset,
                                 additive_sum_offset):
  """Reduce the per-lane sum aggregators and store the adjusted sums.

  Each aggregator is pairwise-added from 16-bit into 32-bit elements. The
  aggregators are then sum-reduced so that every lane's total becomes a
  single 32-bit value, packed four per register. Each total is multiplied by
  `multiplicative_sum_offset`, `additive_sum_offset` is added, and the
  results are stored through `output_address`.

  Args:
    emitter: assembly emitter.
    registers: register allocator used for the multiplier/offset temporaries.
    aggregators: per-lane sum registers; clobbered by the reduction.
    output_address: register addressing the sums destination.
    multiplicative_sum_offset: operand holding the 32-bit multiplier.
    additive_sum_offset: operand holding the 32-bit value added to each sum.
  """
  multiplier = registers.DoubleRegister()
  emitter.EmitVMov('32',
                   emitter.Lane(32, multiplier, 0), multiplicative_sum_offset)

  offset = registers.QuadRegister()
  emitter.EmitVDup('32', offset, additive_sum_offset)

  for aggregator in aggregators:
    emitter.EmitVPaddl('u16', aggregator, aggregator)

  # Four 32-bit totals fit per register, so ceil(n / 4) registers hold the
  # result. Floor division keeps this an int: with `/`, Python 3 produces a
  # float and the slice below raises TypeError.
  reduced_count = (len(aggregators) + 3) // 4
  reduced = aggregators[:reduced_count]

  emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)

  for temp in reduced:
    emitter.EmitVMulScalar('i32', temp, temp, emitter.Lane(32, multiplier, 0))

  for temp in reduced:
    emitter.EmitVAdd('i32', temp, temp, offset)

  emitter.EmitVStoreA(1, 32, reduced,
                      emitter.Dereference(output_address,
                                          _AlignForSums(len(aggregators))))
| |
| |
class RowMajorWithSumUInt8x8(common.StreamGenerator):
  """Generates row-major uint8 pack streams that also accumulate lane sums.

  The emitted code reads `lanes_count` rows, one input address per row spaced
  `params.stride` apart. It stores their bytes to the output 8 per lane per
  step and adds each row's bytes into a per-lane aggregator with 16-bit
  elements. Afterwards the per-row sums are reduced, scaled by
  `params.multiplicative_sum_offset`, offset by `params.additive_sum_offset`
  and stored through the output register.
  """

  def __init__(self, emitter, asm_emitter):
    common.StreamGenerator.__init__(self, emitter, 'RowMajorWithSum')
    # Emitter for the inline assembly. `self.emitter` (presumably set by
    # StreamGenerator) emits the surrounding C++ code.
    self.asm_emitter = asm_emitter

  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    """Emit the packing code for one stream specialization.

    Args:
      in_type: C type name of the input elements; must be 'uint8_t'.
      lanes_count: number of rows packed together.
      pack_size: elements consumed per main-loop iteration; must be 8.
      leftovers: number of trailing elements handled after the main loop
        (GenerateUInt8x8Streams passes 0-7).
    """
    # Compare by value. `is` on int/str literals relies on interning, is not
    # guaranteed, and emits SyntaxWarning on Python 3.8+.
    assert pack_size == 8, pack_size
    assert in_type == 'uint8_t', in_type

    registers = self.asm_emitter.CreateRegisters()

    # The count is copied into a local because the asm maps it as an output
    # operand and decrements it.
    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    output = registers.MapOutputParameter('out')
    inputs = _GenerateInputs(self.asm_emitter, registers, lanes_count,
                             registers.MapOutputParameter('in'),
                             registers.MapParameter('stride', 'params.stride'))
    aggregators = [registers.QuadRegister(8) for _ in range(lanes_count)]

    _GenerateClear(self.asm_emitter, 'i16', aggregators)

    if leftovers:
      # Take the tail off the count first. If nothing else remains, branch
      # straight to the leftovers code at label 2.
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitComment('Reduce count by leftovers.')
      self.asm_emitter.EmitSubs(count, count,
                                self.asm_emitter.ImmediateConstant(leftovers))
      self.asm_emitter.EmitBeqFront(2)

    # Main loop (label 1): 8 elements per lane per iteration until the count
    # reaches zero.
    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitNumericalLabel(1)
    self.asm_emitter.EmitSubs(count, count,
                              self.asm_emitter.ImmediateConstant(8))

    _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
                                aggregators, inputs, output)

    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitBneBack(1)

    if leftovers:
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitNumericalLabel(2)
      _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count,
                                  leftovers, aggregators, inputs, output)

    registers.FreeRegisters(inputs)

    # NOTE(review): above 6 lanes the sum offsets are loaded from memory into
    # general registers instead of being mapped as asm parameters, presumably
    # to stay within the available registers -- confirm.
    if len(inputs) <= 6:
      _GenerateAggregatorReductionHighRegisters(
          self.asm_emitter, registers, aggregators, output)
    else:
      _GenerateAggregatorReductionLowRegisters(
          self.asm_emitter, registers, aggregators, output)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))
| |
| |
def _GenerateColLoadAggregateStore(emitter, registers, lanes_count,
                                   elements_count, aggregators, input_address,
                                   stride, output):
  """Emit one load / aggregate / store step for column-major input.

  The lane block is loaded with emitter.EmitLoadColBlock, which may return a
  different register list than it was given. The returned registers are
  widened-and-added into the per-lane aggregators, then stored through
  `output` as 8 * lanes_count bytes.

  Args:
    emitter: assembly emitter.
    registers: register allocator; the block registers are freed at the end.
    lanes_count: number of lanes.
    elements_count: elements loaded per lane; 8 for a full step, fewer for
      the leftovers tail.
    aggregators: per-lane sum registers, updated in place.
    input_address: register holding the input address.
    stride: register holding the stride passed to EmitLoadColBlock.
    output: output address register.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store - column major %dx%d' %
                      (lanes_count, elements_count))

  block = [registers.DoubleRegister() for _ in range(lanes_count)]

  # Zero the block before a partial load so unloaded bytes contribute 0.
  # Compare by value; `is not 8` was an identity test on an int.
  if elements_count != 8:
    _GenerateClear(emitter, 'i8', block)

  block = emitter.EmitLoadColBlock(registers, 8, lanes_count, elements_count,
                                   block, input_address, stride)

  for (aggregator, row) in zip(aggregators, block):
    emitter.EmitVAddw('u8', aggregator, aggregator, row)

  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                       _AlignForLanes(lanes_count))

  registers.FreeRegisters(block)
| |
| |
class ColumnMajorWithSumUInt8x8(common.StreamGenerator):
  """Generates column-major uint8 pack streams that also accumulate lane sums.

  Like the row-major variant, but each step's lane block is loaded with
  EmitLoadColBlock from a single input address and stride. The per-lane sums
  are reduced, scaled by `params.multiplicative_sum_offset`, offset by
  `params.additive_sum_offset` and stored through the output register.
  """

  def __init__(self, emitter, asm_emitter):
    common.StreamGenerator.__init__(self, emitter, 'ColumnMajorWithSum')
    # Emitter for the inline assembly. `self.emitter` (presumably set by
    # StreamGenerator) emits the surrounding C++ code.
    self.asm_emitter = asm_emitter

  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    """Emit the packing code for one stream specialization.

    Args:
      in_type: C type name of the input elements; must be 'uint8_t'.
      lanes_count: number of lanes packed together.
      pack_size: elements consumed per main-loop iteration; must be 8.
      leftovers: number of trailing elements handled after the main loop
        (GenerateUInt8x8Streams passes 0-7).
    """
    # Compare by value. `is` on int/str literals relies on interning, is not
    # guaranteed, and emits SyntaxWarning on Python 3.8+.
    assert pack_size == 8, pack_size
    assert in_type == 'uint8_t', in_type

    registers = self.asm_emitter.CreateRegisters()

    # Count and stride are copied into locals because the asm maps both as
    # output operands, so it may modify them.
    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')
    self.emitter.EmitDeclare('int', 'params_stride_copy', 'params.stride')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    input_address = registers.MapOutputParameter('in')
    output_address = registers.MapOutputParameter('out')
    aggregators = [registers.QuadRegister(8) for _ in range(lanes_count)]
    stride = registers.MapOutputParameter('stride', 'params_stride_copy')

    # NOTE(review): `stride` is passed as both arguments, presumably to
    # rewrite it in place into the per-block stride used by EmitLoadColBlock
    # -- confirm against the emitter.
    self.asm_emitter.EmitColBlockStride(lanes_count, stride, stride)

    _GenerateClear(self.asm_emitter, 'i16', aggregators)

    if leftovers:
      # Take the tail off the count first. If nothing else remains, branch
      # straight to the leftovers code at label 2.
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitComment('Reduce count by leftovers.')
      self.asm_emitter.EmitSubs(count, count,
                                self.asm_emitter.ImmediateConstant(leftovers))
      self.asm_emitter.EmitBeqFront(2)

    # Main loop (label 1): 8 elements per lane per iteration until the count
    # reaches zero.
    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitNumericalLabel(1)
    self.asm_emitter.EmitSubs(count, count,
                              self.asm_emitter.ImmediateConstant(8))

    _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
                                   aggregators, input_address, stride,
                                   output_address)

    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitBneBack(1)

    if leftovers:
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitNumericalLabel(2)
      _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count,
                                     leftovers, aggregators, input_address,
                                     stride, output_address)

    _GenerateAggregatorReductionHighRegisters(
        self.asm_emitter, registers, aggregators, output_address)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))
| |
| |
def GenerateUInt8x8Streams(cc_emitter, asm_emitter, lanes_count):
  """Emit every uint8 row-major and column-major "with sum" specialization.

  For each lane count in 1..lanes_count and each leftovers value in 0..7, a
  specialization with pack size 8 is emitted. All row-major variants come
  first, then all column-major variants.

  Args:
    cc_emitter: emitter for the surrounding C++ code.
    asm_emitter: emitter for the inline assembly.
    lanes_count: maximum number of lanes to specialize for.
  """
  row_major_with_sum = RowMajorWithSumUInt8x8(cc_emitter, asm_emitter)
  column_major_with_sum = ColumnMajorWithSumUInt8x8(cc_emitter, asm_emitter)

  # A distinct loop variable keeps the `lanes_count` limit intact. Reusing
  # the parameter name would rebind it inside the first loop.
  for lanes in range(1, lanes_count + 1):
    for leftovers in range(8):
      row_major_with_sum.SpecializeStream('uint8_t', lanes, 8, leftovers)

  for lanes in range(1, lanes_count + 1):
    for leftovers in range(8):
      column_major_with_sum.SpecializeStream('uint8_t', lanes, 8, leftovers)