From 58914796c06662f4f901a4f195057ee1327cf055 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Tue, 16 Feb 2021 19:50:23 -0300
Subject: [PATCH] shader: Add XMAD multiplication folding optimization

---
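Note for reviewers: the matched pattern computes (a & 0xFFFF) * b + (((a >> 16) * b) << 16),
which equals a * b in 32-bit modular arithmetic, so folding it into a single IMul32 preserves
the result. A small standalone sketch checking that identity follows; the helper name below is
illustrative only and is not part of the patch:

    #include <cassert>
    #include <cstdint>

    // XMAD-style decomposition of a 32-bit multiply: multiply the low and high
    // 16-bit halves of 'a' by 'b' separately, then recombine. Unsigned overflow
    // wraps modulo 2^32, so the result equals a * b.
    static std::uint32_t XmadStyleMultiply(std::uint32_t a, std::uint32_t b) {
        const std::uint32_t lo{(a & 0xffff) * b};
        const std::uint32_t hi{(a >> 16) * b};
        return (hi << 16) + lo;
    }

    int main() {
        // Spot-check a few values, including ones that overflow 32 bits.
        assert(XmadStyleMultiply(0x12345678u, 0x9abcdef0u) == 0x12345678u * 0x9abcdef0u);
        assert(XmadStyleMultiply(0xffffffffu, 0xffffffffu) == 0xffffffffu * 0xffffffffu);
        assert(XmadStyleMultiply(65536u, 3u) == 65536u * 3u);
        return 0;
    }
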
 .../ir_opt/constant_propagation_pass.cpp      | 82 +++++++++++++++++--
 1 file changed, 77 insertions(+), 5 deletions(-)

diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index f1ad16d60..9eb61b54c 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -9,6 +9,7 @@
 #include "common/bit_cast.h"
 #include "common/bit_util.h"
 #include "shader_recompiler/exception.h"
+#include "shader_recompiler/frontend/ir/ir_emitter.h"
 #include "shader_recompiler/frontend/ir/microinstruction.h"
 #include "shader_recompiler/ir_opt/passes.h"
 
@@ -99,8 +100,71 @@ void FoldGetPred(IR::Inst& inst) {
     }
 }
 
+/// Replaces the pattern generated by two XMAD multiplications
+bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
+    /*
+     * We are looking for this pattern:
+     *   %rhs_bfe = BitFieldUExtract %factor_a, #0, #16 (uses: 1)
+     *   %rhs_mul = IMul32 %rhs_bfe, %factor_b (uses: 1)
+     *   %lhs_bfe = BitFieldUExtract %factor_a, #16, #16 (uses: 1)
+     *   %lhs_mul = IMul32 %lhs_bfe, %factor_b (uses: 1)
+     *   %lhs_shl = ShiftLeftLogical32 %lhs_mul, #16 (uses: 1)
+     *   %result  = IAdd32 %lhs_shl, %rhs_mul (uses: 10)
+     *
+     * And replacing it with
+     *   %result  = IMul32 %factor_a, %factor_b
+     *
+     * LLVM and MSVC perform this same fold, so the transformation is known to be safe.
+     */
+    const IR::Value lhs_arg{inst.Arg(0)};
+    const IR::Value rhs_arg{inst.Arg(1)};
+    if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
+    if (lhs_shl->Opcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) {
+        return false;
+    }
+    if (lhs_shl->Arg(0).IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
+    IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
+    if (lhs_mul->Opcode() != IR::Opcode::IMul32 || rhs_mul->Opcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
+        return false;
+    }
+    const IR::U32 factor_b{lhs_mul->Arg(1)};
+    if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
+    IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
+    if (lhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (rhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) {
+        return false;
+    }
+    if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
+        return false;
+    }
+    if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) {
+        return false;
+    }
+    const IR::U32 factor_a{lhs_bfe->Arg(0)};
+    IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
+    inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
+    return true;
+}
+
 template <typename T>
-void FoldAdd(IR::Inst& inst) {
+void FoldAdd(IR::Block& block, IR::Inst& inst) {
     if (inst.HasAssociatedPseudoOperation()) {
         return;
     }
@@ -110,6 +174,12 @@ void FoldAdd(IR::Inst& inst) {
     const IR::Value rhs{inst.Arg(1)};
     if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
         inst.ReplaceUsesWith(inst.Arg(0));
+        return;
+    }
+    if constexpr (std::is_same_v<T, u32>) {
+        if (FoldXmadMultiply(block, inst)) {
+            return;
+        }
     }
 }
 
@@ -244,14 +314,14 @@ void FoldBranchConditional(IR::Inst& inst) {
     }
 }
 
-void ConstantPropagation(IR::Inst& inst) {
+void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
     switch (inst.Opcode()) {
     case IR::Opcode::GetRegister:
         return FoldGetRegister(inst);
     case IR::Opcode::GetPred:
         return FoldGetPred(inst);
     case IR::Opcode::IAdd32:
-        return FoldAdd<u32>(inst);
+        return FoldAdd<u32>(block, inst);
     case IR::Opcode::ISub32:
         return FoldISub32(inst);
     case IR::Opcode::BitCastF32U32:
@@ -259,7 +329,7 @@ void ConstantPropagation(IR::Inst& inst) {
     case IR::Opcode::BitCastU32F32:
         return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32);
     case IR::Opcode::IAdd64:
-        return FoldAdd<u64>(inst);
+        return FoldAdd<u64>(block, inst);
     case IR::Opcode::Select32:
         return FoldSelect<u32>(inst);
     case IR::Opcode::LogicalAnd:
@@ -292,7 +362,9 @@ void ConstantPropagation(IR::Inst& inst) {
 } // Anonymous namespace
 
 void ConstantPropagationPass(IR::Block& block) {
-    std::ranges::for_each(block, ConstantPropagation);
+    for (IR::Inst& inst : block) {
+        ConstantPropagation(block, inst);
+    }
 }
 
 } // namespace Shader::Optimization