| /* |
| * Copyright 2023 Valve Corporation |
| * SPDX-License-Identifier: MIT |
| */ |
| #include "nir.h" |
| #include "nir_builder.h" |
| #include "nir_worklist.h" |
| |
| /* |
| * This pass recognizes certain patterns of nir_op_shfr and nir_op_msad_4x8 and replaces it |
| * with a single nir_op_mqsad_4x8 instruction. |
| */ |
| |
| struct mqsad { |
| nir_scalar ref; |
| nir_scalar src[2]; |
| |
| nir_scalar accum[4]; |
| nir_alu_instr *msad[4]; |
| unsigned first_msad_index; |
| uint8_t mask; |
| }; |
| |
| static bool |
| is_mqsad_compatible(struct mqsad *mqsad, nir_scalar ref, nir_scalar src0, nir_scalar src1, |
| unsigned idx, nir_alu_instr *msad) |
| { |
| if (!nir_scalar_equal(ref, mqsad->ref) || !nir_scalar_equal(src0, mqsad->src[0])) |
| return false; |
| if ((mqsad->mask & 0b1110) && idx && !nir_scalar_equal(src1, mqsad->src[1])) |
| return false; |
| |
| /* Ensure that this MSAD doesn't depend on any previous MSAD. */ |
| nir_instr_worklist *wl = nir_instr_worklist_create(); |
| nir_instr_worklist_add_ssa_srcs(wl, &msad->instr); |
| nir_foreach_instr_in_worklist(instr, wl) { |
| if (instr->block != msad->instr.block || instr->index < mqsad->first_msad_index) |
| continue; |
| |
| u_foreach_bit(i, mqsad->mask) { |
| if (instr == &mqsad->msad[i]->instr) { |
| nir_instr_worklist_destroy(wl); |
| return false; |
| } |
| } |
| |
| nir_instr_worklist_add_ssa_srcs(wl, instr); |
| } |
| nir_instr_worklist_destroy(wl); |
| |
| return true; |
| } |
| |
| static void |
| parse_msad(nir_alu_instr *msad, struct mqsad *mqsad) |
| { |
| if (msad->def.num_components != 1) |
| return; |
| |
| nir_scalar msad_s = nir_get_scalar(&msad->def, 0); |
| nir_scalar ref = nir_scalar_chase_alu_src(msad_s, 0); |
| nir_scalar accum = nir_scalar_chase_alu_src(msad_s, 2); |
| |
| unsigned idx = 0; |
| nir_scalar src0 = nir_scalar_chase_alu_src(msad_s, 1); |
| nir_scalar src1; |
| if (nir_scalar_is_alu(src0) && nir_scalar_alu_op(src0) == nir_op_shfr) { |
| nir_scalar amount_s = nir_scalar_chase_alu_src(src0, 2); |
| uint32_t amount = nir_scalar_is_const(amount_s) ? nir_scalar_as_uint(amount_s) : 0; |
| if (amount == 8 || amount == 16 || amount == 24) { |
| idx = amount / 8; |
| src1 = nir_scalar_chase_alu_src(src0, 0); |
| src0 = nir_scalar_chase_alu_src(src0, 1); |
| } |
| } |
| |
| if (mqsad->mask && !is_mqsad_compatible(mqsad, ref, src0, src1, idx, msad)) |
| memset(mqsad, 0, sizeof(*mqsad)); |
| |
| /* Add this instruction to the in-progress MQSAD. */ |
| mqsad->ref = ref; |
| mqsad->src[0] = src0; |
| if (idx) |
| mqsad->src[1] = src1; |
| |
| mqsad->accum[idx] = accum; |
| mqsad->msad[idx] = msad; |
| if (!mqsad->mask) |
| mqsad->first_msad_index = msad->instr.index; |
| mqsad->mask |= 1 << idx; |
| } |
| |
| static void |
| create_msad(nir_builder *b, struct mqsad *mqsad) |
| { |
| nir_def *mqsad_def = nir_mqsad_4x8(b, nir_channel(b, mqsad->ref.def, mqsad->ref.comp), |
| nir_vec_scalars(b, mqsad->src, 2), |
| nir_vec_scalars(b, mqsad->accum, 4)); |
| |
| for (unsigned i = 0; i < 4; i++) |
| nir_def_rewrite_uses(&mqsad->msad[i]->def, nir_channel(b, mqsad_def, i)); |
| |
| memset(mqsad, 0, sizeof(*mqsad)); |
| } |
| |
| bool |
| nir_opt_mqsad(nir_shader *shader) |
| { |
| bool progress = false; |
| nir_foreach_function_impl(impl, shader) { |
| bool progress_impl = false; |
| |
| nir_metadata_require(impl, nir_metadata_instr_index); |
| |
| nir_foreach_block(block, impl) { |
| struct mqsad mqsad; |
| memset(&mqsad, 0, sizeof(mqsad)); |
| |
| nir_foreach_instr(instr, block) { |
| if (instr->type != nir_instr_type_alu) |
| continue; |
| |
| nir_alu_instr *alu = nir_instr_as_alu(instr); |
| if (alu->op != nir_op_msad_4x8) |
| continue; |
| |
| parse_msad(alu, &mqsad); |
| |
| if (mqsad.mask == 0xf) { |
| nir_builder b = nir_builder_at(nir_before_instr(instr)); |
| create_msad(&b, &mqsad); |
| progress_impl = true; |
| } |
| } |
| } |
| |
| if (progress_impl) { |
| nir_metadata_preserve(impl, nir_metadata_control_flow); |
| progress = true; |
| } else { |
| nir_metadata_preserve(impl, nir_metadata_block_index); |
| } |
| } |
| |
| return progress; |
| } |