From 2044a289f8781199c745a040056b56543fd7ed0e Mon Sep 17 00:00:00 2001
From: Liam <byteslice@airmail.cc>
Date: Wed, 10 Jan 2024 22:58:18 -0500
Subject: [PATCH] shader_recompiler: fix Offset operand usage for
 non-OpImage*Gather

---
 .../backend/spirv/emit_spirv_image.cpp        | 76 +++++++++++++------
 .../backend/spirv/emit_spirv_instructions.h   |  2 +-
 2 files changed, 55 insertions(+), 23 deletions(-)

diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_image.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_image.cpp
index 800754554..64a4e0e55 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_image.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_image.cpp
@@ -12,6 +12,11 @@ namespace Shader::Backend::SPIRV {
 namespace {
 class ImageOperands {
 public:
+    [[maybe_unused]] static constexpr bool ImageSampleOffsetAllowed = false;
+    [[maybe_unused]] static constexpr bool ImageGatherOffsetAllowed = true;
+    [[maybe_unused]] static constexpr bool ImageFetchOffsetAllowed = false;
+    [[maybe_unused]] static constexpr bool ImageGradientOffsetAllowed = false;
+
     explicit ImageOperands(EmitContext& ctx, bool has_bias, bool has_lod, bool has_lod_clamp,
                            Id lod, const IR::Value& offset) {
         if (has_bias) {
@@ -22,7 +27,7 @@ public:
             const Id lod_value{has_lod_clamp ? ctx.OpCompositeExtract(ctx.F32[1], lod, 0) : lod};
             Add(spv::ImageOperandsMask::Lod, lod_value);
         }
-        AddOffset(ctx, offset);
+        AddOffset(ctx, offset, ImageSampleOffsetAllowed);
         if (has_lod_clamp) {
             const Id lod_clamp{has_bias ? ctx.OpCompositeExtract(ctx.F32[1], lod, 1) : lod};
             Add(spv::ImageOperandsMask::MinLod, lod_clamp);
@@ -55,20 +60,17 @@ public:
         Add(spv::ImageOperandsMask::ConstOffsets, offsets);
     }
 
-    explicit ImageOperands(Id offset, Id lod, Id ms) {
+    explicit ImageOperands(Id lod, Id ms) {
         if (Sirit::ValidId(lod)) {
             Add(spv::ImageOperandsMask::Lod, lod);
         }
-        if (Sirit::ValidId(offset)) {
-            Add(spv::ImageOperandsMask::Offset, offset);
-        }
         if (Sirit::ValidId(ms)) {
             Add(spv::ImageOperandsMask::Sample, ms);
         }
     }
 
     explicit ImageOperands(EmitContext& ctx, bool has_lod_clamp, Id derivatives,
-                           u32 num_derivatives, Id offset, Id lod_clamp) {
+                           u32 num_derivatives, const IR::Value& offset, Id lod_clamp) {
         if (!Sirit::ValidId(derivatives)) {
             throw LogicError("Derivatives must be present");
         }
@@ -83,16 +85,14 @@ public:
         const Id derivatives_Y{ctx.OpCompositeConstruct(
             ctx.F32[num_derivatives], std::span{deriv_y_accum.data(), deriv_y_accum.size()})};
         Add(spv::ImageOperandsMask::Grad, derivatives_X, derivatives_Y);
-        if (Sirit::ValidId(offset)) {
-            Add(spv::ImageOperandsMask::Offset, offset);
-        }
+        AddOffset(ctx, offset, ImageGradientOffsetAllowed);
         if (has_lod_clamp) {
             Add(spv::ImageOperandsMask::MinLod, lod_clamp);
         }
     }
 
     explicit ImageOperands(EmitContext& ctx, bool has_lod_clamp, Id derivatives_1, Id derivatives_2,
-                           Id offset, Id lod_clamp) {
+                           const IR::Value& offset, Id lod_clamp) {
         if (!Sirit::ValidId(derivatives_1) || !Sirit::ValidId(derivatives_2)) {
             throw LogicError("Derivatives must be present");
         }
@@ -111,9 +111,7 @@ public:
         const Id derivatives_id2{ctx.OpCompositeConstruct(
             ctx.F32[3], std::span{deriv_2_accum.data(), deriv_2_accum.size()})};
         Add(spv::ImageOperandsMask::Grad, derivatives_id1, derivatives_id2);
-        if (Sirit::ValidId(offset)) {
-            Add(spv::ImageOperandsMask::Offset, offset);
-        }
+        AddOffset(ctx, offset, ImageGradientOffsetAllowed);
         if (has_lod_clamp) {
             Add(spv::ImageOperandsMask::MinLod, lod_clamp);
         }
@@ -132,7 +130,7 @@ public:
     }
 
 private:
-    void AddOffset(EmitContext& ctx, const IR::Value& offset) {
+    void AddOffset(EmitContext& ctx, const IR::Value& offset, bool runtime_offset_allowed) {
         if (offset.IsEmpty()) {
             return;
         }
@@ -165,7 +163,9 @@ private:
                 break;
             }
         }
-        Add(spv::ImageOperandsMask::Offset, ctx.Def(offset));
+        if (runtime_offset_allowed) {
+            Add(spv::ImageOperandsMask::Offset, ctx.Def(offset));
+        }
     }
 
     void Add(spv::ImageOperandsMask new_mask, Id value) {
@@ -311,6 +311,37 @@ Id ImageGatherSubpixelOffset(EmitContext& ctx, const IR::TextureInstInfo& info,
         return coords;
     }
 }
+
+void AddOffsetToCoordinates(EmitContext& ctx, const IR::TextureInstInfo& info, Id& coords,
+                            Id offset) {
+    if (!Sirit::ValidId(offset)) {
+        return;
+    }
+
+    Id result_type{};
+    switch (info.type) {
+    case TextureType::Buffer:
+    case TextureType::Color1D:
+    case TextureType::ColorArray1D: {
+        result_type = ctx.U32[1];
+        break;
+    }
+    case TextureType::Color2D:
+    case TextureType::Color2DRect:
+    case TextureType::ColorArray2D: {
+        result_type = ctx.U32[2];
+        break;
+    }
+    case TextureType::Color3D: {
+        result_type = ctx.U32[3];
+        break;
+    }
+    case TextureType::ColorCube:
+    case TextureType::ColorArrayCube:
+        return;
+    }
+    coords = ctx.OpIAdd(result_type, coords, offset);
+}
 } // Anonymous namespace
 
 Id EmitBindlessImageSampleImplicitLod(EmitContext&) {
@@ -496,6 +527,7 @@ Id EmitImageGatherDref(EmitContext& ctx, IR::Inst* inst, const IR::Value& index,
 Id EmitImageFetch(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords, Id offset,
                   Id lod, Id ms) {
     const auto info{inst->Flags<IR::TextureInstInfo>()};
+    AddOffsetToCoordinates(ctx, info, coords, offset);
     if (info.type == TextureType::Buffer) {
         lod = Id{};
     }
@@ -503,7 +535,7 @@ Id EmitImageFetch(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id c
         // This image is multisampled, lod must be implicit
         lod = Id{};
     }
-    const ImageOperands operands(offset, lod, ms);
+    const ImageOperands operands(lod, ms);
     return Emit(&EmitContext::OpImageSparseFetch, &EmitContext::OpImageFetch, ctx, inst, ctx.F32[4],
                 TextureImage(ctx, info, index), coords, operands.MaskOptional(), operands.Span());
 }
@@ -548,13 +580,13 @@ Id EmitImageQueryLod(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, I
 }
 
 Id EmitImageGradient(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords,
-                     Id derivatives, Id offset, Id lod_clamp) {
+                     Id derivatives, const IR::Value& offset, Id lod_clamp) {
     const auto info{inst->Flags<IR::TextureInstInfo>()};
-    const auto operands =
-        info.num_derivatives == 3
-            ? ImageOperands(ctx, info.has_lod_clamp != 0, derivatives, offset, {}, lod_clamp)
-            : ImageOperands(ctx, info.has_lod_clamp != 0, derivatives, info.num_derivatives, offset,
-                            lod_clamp);
+    const auto operands = info.num_derivatives == 3
+                              ? ImageOperands(ctx, info.has_lod_clamp != 0, derivatives,
+                                              ctx.Def(offset), {}, lod_clamp)
+                              : ImageOperands(ctx, info.has_lod_clamp != 0, derivatives,
+                                              info.num_derivatives, offset, lod_clamp);
     return Emit(&EmitContext::OpImageSparseSampleExplicitLod,
                 &EmitContext::OpImageSampleExplicitLod, ctx, inst, ctx.F32[4],
                 Texture(ctx, info, index), coords, operands.Mask(), operands.Span());
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
index 7d34575c8..5c01b1012 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
@@ -543,7 +543,7 @@ Id EmitImageQueryDimensions(EmitContext& ctx, IR::Inst* inst, const IR::Value& i
                             const IR::Value& skip_mips);
 Id EmitImageQueryLod(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords);
 Id EmitImageGradient(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords,
-                     Id derivatives, Id offset, Id lod_clamp);
+                     Id derivatives, const IR::Value& offset, Id lod_clamp);
 Id EmitImageRead(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords);
 void EmitImageWrite(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords, Id color);
 Id EmitIsTextureScaled(EmitContext& ctx, const IR::Value& index);