| """Zip primitive used by the GEMM function. |
| |
| Takes 1 to 3 rows of data and interleaves them in 8 byte chunks. Pads to |
| multiply of 8 length with zeros. Calculates row sums and appends those at the |
| end. |
| """ |
| |
| import neon_emitter |
| |
| |
| class Error(Exception): |
| """Module level error.""" |
| |
| |
| class ConfigurationError(Error): |
| """Unsupported configuration.""" |
| |
| |
| class ZipLane(object): |
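  """Registers for zipping one input row: the source address register, an
  8-byte load register and a 16-bit sum aggregator."""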
| |
| def __init__(self, input_address, load, aggregator): |
| self.input_address = input_address |
| self.load = load |
| self.aggregator = aggregator |
| |
| |
| def GenerateZipLanes(emitter, registers, zip_lanes, input_address, stride): |
| """Prepares read lanes for the zip operation. |
| |
| Args: |
| emitter: ARM/NEON emitter. |
| registers: ARM/NEON registers state. |
| zip_lanes: number of lanes to prepare. |
| input_address: register that contains the input address for the first lane. |
| stride: memory stride for lane inputs. |
| |
| Returns: |
| Array of ZipLane objects. |
| """ |
| lanes = [] |
| last_address_register = input_address |
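  # Lane 0 reads directly from input_address; each further lane reads from the
  # previous lane's address advanced by stride.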
  for i in range(zip_lanes):
    if i == 0:
| lanes.append(ZipLane(input_address, registers.DoubleRegister(), |
| registers.QuadRegister(2))) |
| else: |
| address_register = registers.GeneralRegister() |
| lanes.append(ZipLane(address_register, registers.DoubleRegister(), |
| registers.QuadRegister(2))) |
| emitter.EmitAdd(address_register, last_address_register, stride) |
| last_address_register = address_register |
| return lanes |
| |
| |
| def BuildName(zip_lanes, leftovers, aligned): |
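  """Build the function name, e.g. zip_2x8_3_aligned."""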
| name = 'zip_%dx8' % zip_lanes |
| if leftovers: |
| name += '_%d' % leftovers |
| if aligned: |
| name += '_aligned' |
| return name |
| |
| |
| def GenerateClearAggregators(emitter, lanes): |
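  """Zero every lane's 16-bit sum aggregator."""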
| for lane in lanes: |
| emitter.EmitVMov('i16', lane.aggregator, emitter.ImmediateConstant(0)) |
| |
| |
| def GenerateLoadAggregateStore(emitter, lanes, output_address, alignment): |
| """Emit inner loop code for reading N lanes and interweaving them.""" |
| emitter.EmitNewline() |
| emitter.EmitComment('Load Aggregate Store.') |
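  # Load 8 bytes per lane, widen-accumulate into the 16-bit aggregators
  # (vaddw.u8), then store all lane registers back-to-back, interleaving the
  # rows in 8-byte chunks.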
| |
| for lane in lanes: |
| emitter.EmitVLoad( |
| '1.8', lane.load, |
| emitter.DereferenceIncrement(lane.input_address, alignment)) |
| |
| store_registers = [] |
| for lane in lanes: |
| emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load) |
| store_registers.append(lane.load) |
| |
| emitter.EmitVStoreA('1.8', store_registers, |
| emitter.DereferenceIncrement(output_address, 64)) |
| |
| |
| def GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes, |
| output_address): |
| """Handle leftovers when count is not a multiply of 8.""" |
| emitter.EmitNewline() |
| emitter.EmitComment('Leftover Load Aggregate Store.') |
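  # Decompose the leftover count into at most one 32-, 16- and 8-bit load each
  # (e.g. 7 bytes = 4 + 2 + 1); the remaining bytes stay zero thanks to the
  # cleared load registers.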
| |
| # Clear load registers. |
| for lane in lanes: |
| emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0)) |
| |
| if leftovers == 1: |
| # Load 8 bits. |
| for lane in lanes: |
| emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 0), |
| emitter.Dereference(lane.input_address, None)) |
| elif leftovers == 2: |
| # Load 16 bits. |
| for lane in lanes: |
| emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 0), |
| emitter.Dereference(lane.input_address, None)) |
| elif leftovers == 3: |
| # Load 16 bits. |
| for lane in lanes: |
| emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 0), |
| emitter.DereferenceIncrement(lane.input_address, None)) |
| # Load 8 bits. |
| for lane in lanes: |
| emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 2), |
| emitter.Dereference(lane.input_address, None)) |
| elif leftovers == 4: |
| # Load 32 bits. |
| for lane in lanes: |
| emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0), |
| emitter.Dereference(lane.input_address, None)) |
| elif leftovers == 5: |
    # Load 32 bits.
| for lane in lanes: |
| emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0), |
| emitter.DereferenceIncrement(lane.input_address, None)) |
| # Load 8 bits. |
| for lane in lanes: |
| emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 4), |
| emitter.Dereference(lane.input_address, None)) |
| elif leftovers == 6: |
    # Load 32 bits.
| for lane in lanes: |
| emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0), |
| emitter.DereferenceIncrement(lane.input_address, None)) |
| # Load 16 bits. |
| for lane in lanes: |
| emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 2), |
| emitter.Dereference(lane.input_address, None)) |
| elif leftovers == 7: |
    # Load 32 bits.
| for lane in lanes: |
| emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0), |
| emitter.DereferenceIncrement(lane.input_address, None)) |
| # Load 16 bits. |
| for lane in lanes: |
| emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 2), |
| emitter.DereferenceIncrement(lane.input_address, None)) |
| # Load 8 bits. |
| for lane in lanes: |
| emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 6), |
| emitter.Dereference(lane.input_address, None)) |
| else: |
| raise ConfigurationError('Unsupported leftover num: %d' % leftovers) |
| |
| # Aggregate. |
| store_registers = [] |
| for lane in lanes: |
| emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load) |
| store_registers.append(lane.load) |
| |
| # Store. |
| emitter.EmitVStoreA('1.8', store_registers, |
| emitter.DereferenceIncrement(output_address, 64)) |
| |
| |
| def GenerateAggregatorReduction(emitter, registers, lanes, output_address, |
| multiplicative_offset, additive_offset): |
| """Reduce 4 lane sum aggregators to 1 value and store the sums.""" |
| emitter.EmitNewline() |
| emitter.EmitComment('Aggregator Reduction.') |
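  # Each aggregator holds 8 16-bit partial sums; pairwise additions (vpaddl,
  # vpadd) collapse them to one 32-bit sum per lane, which is then multiplied
  # by multiplicative_offset and increased by additive_offset before storing.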
| |
| multiplier = registers.DoubleRegister() |
| emitter.EmitVMov('32', emitter.Lane(multiplier, 0), multiplicative_offset) |
| offset = registers.QuadRegister() |
| emitter.EmitVDup('32', offset, additive_offset) |
| |
| lane_temps = [] |
| for lane in lanes: |
| emitter.EmitVPaddl('u16', lane.aggregator, lane.aggregator) |
| |
| for lane in lanes: |
| lane_temp = registers.DoubleRegister() |
| lane_temps.append(lane_temp) |
| emitter.EmitVPadd('u32', lane_temp, registers.Low(lane.aggregator), |
| registers.High(lane.aggregator)) |
| |
| temp = registers.QuadRegister() |
| low = registers.Low(temp) |
| high = registers.High(temp) |
| |
| if len(lanes) == 1: |
| emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[0]) |
| elif len(lanes) == 2: |
| emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1]) |
| elif len(lanes) == 3: |
| emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1]) |
| emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[2]) |
| elif len(lanes) == 4: |
| emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1]) |
| emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[3]) |
| else: |
| raise ConfigurationError('Unexpected number of aggregators to reduce: %d' % |
| len(lanes)) |
| |
| emitter.EmitVMul('i32', temp, temp, emitter.Lane(multiplier, 0)) |
| emitter.EmitVAdd('i32', temp, temp, offset) |
| |
| if len(lanes) == 1: |
| emitter.EmitVStore('1.32', emitter.Lane(low, 0), |
| emitter.Dereference(output_address, None)) |
| elif len(lanes) == 2: |
| emitter.EmitVStore('1.32', low, emitter.Dereference(output_address, 64)) |
| elif len(lanes) == 3: |
| emitter.EmitVStore('1.32', low, |
| emitter.DereferenceIncrement(output_address, 64)) |
| emitter.EmitVStore('1.32', emitter.Lane(high, 0), |
| emitter.Dereference(output_address, None)) |
| elif len(lanes) == 4: |
| emitter.EmitVStoreA('1.32', [low, high], |
| emitter.DereferenceIncrement(output_address, 64)) |
| |
| |
| def GenerateZipNx8(emitter, zip_lanes, leftovers, aligned): |
| """Emit the zip function for a given number of rows and row size leftovers.""" |
| if leftovers < 0 or leftovers > 7: |
| raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.') |
| if zip_lanes < 1 or zip_lanes > 4: |
    raise ConfigurationError('Zip_lanes should be 1, 2, 3 or 4.')
| |
| name = BuildName(zip_lanes, leftovers, aligned) |
| |
| emitter.EmitFunctionBeginA( |
| name, [['const std::uint8_t*', 'source'], ['std::int32_t', 'count'], |
| ['std::int32_t', 'stride'], ['std::uint8_t*', 'destination'], |
| ['std::int32_t', 'multiplicative_offset'], |
| ['std::int32_t', 'additive_offset']], 'void') |
| emitter.EmitAssert('count %% 8 == %d' % leftovers) |
| emitter.EmitAssert('count <= 2048') |
| emitter.EmitAssert('count >= 8') |
| emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0') |
| if aligned: |
| emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0') |
| if zip_lanes > 1: |
| emitter.EmitAssert('stride % 8 == 0') |
| emitter.EmitAsmBegin() |
| |
| registers = neon_emitter.NeonRegisters() |
| |
| count = registers.MapParameter('count') |
| output_address = registers.MapParameter('destination') |
| |
| lanes = GenerateZipLanes(emitter, registers, zip_lanes, |
| registers.MapParameter('source'), |
| registers.MapParameter('stride')) |
| |
| if leftovers: |
| emitter.EmitSub(count, count, emitter.ImmediateConstant(leftovers)) |
| |
| GenerateClearAggregators(emitter, lanes) |
| |
| emitter.EmitNewline() |
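  # Main loop: each iteration consumes 8 bytes from every lane.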
| emitter.EmitNumericalLabel(1) |
| emitter.EmitSubs(count, count, emitter.ImmediateConstant(8)) |
| |
  GenerateLoadAggregateStore(emitter, lanes, output_address,
                             64 if aligned else None)
| |
| emitter.EmitNewline() |
| emitter.EmitBneBack(1) |
| |
| if leftovers: |
| GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes, |
| output_address) |
| |
| GenerateAggregatorReduction(emitter, registers, lanes, output_address, |
| registers.MapParameter('multiplicative_offset'), |
| registers.MapParameter('additive_offset')) |
| |
| emitter.EmitAsmEnd(registers.MappedParameters(), [], |
| registers.Clobbers() + ['cc', 'memory']) |
| emitter.EmitFunctionEnd() |
| |
| |
| def GenerateFunctions(emitter): |
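  """Generate all zip variants: 1-4 lanes, 0-7 leftovers, (un)aligned."""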
| for aligned in [True, False]: |
| for lanes in range(1, 5): |
| for leftovers in range(0, 8): |
| GenerateZipNx8(emitter, lanes, leftovers, aligned) |
| emitter.EmitNewline() |
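

# A minimal driver sketch, assuming neon_emitter provides a NeonEmitter class
# that prints the generated assembly to stdout (hypothetical entry point,
# matching how generator scripts like this one are typically driven):
#
#   if __name__ == '__main__':
#     GenerateFunctions(neon_emitter.NeonEmitter())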