From da34d3704405665b68d3d992f37a7eeb541238af Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Tue, 25 May 2021 20:37:06 -0300
Subject: [PATCH] common/thread_worker: Add support for stateful threads

---
 src/common/CMakeLists.txt    |  1 -
 src/common/thread_worker.cpp | 66 ------------------------
 src/common/thread_worker.h   | 97 ++++++++++++++++++++++++++++++++----
 3 files changed, 86 insertions(+), 78 deletions(-)
 delete mode 100644 src/common/thread_worker.cpp

diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt
index c05b78cd5..e03fffd8d 100644
--- a/src/common/CMakeLists.txt
+++ b/src/common/CMakeLists.txt
@@ -180,7 +180,6 @@ add_library(common STATIC
     thread.cpp
     thread.h
     thread_queue_list.h
-    thread_worker.cpp
     thread_worker.h
     threadsafe_queue.h
     time_zone.cpp
diff --git a/src/common/thread_worker.cpp b/src/common/thread_worker.cpp
deleted file mode 100644
index 32be49b15..000000000
--- a/src/common/thread_worker.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2020 yuzu emulator team
-// Licensed under GPLv2 or any later version
-// Refer to the license.txt file included.
-
-#include "common/thread.h"
-#include "common/thread_worker.h"
-
-namespace Common {
-
-ThreadWorker::ThreadWorker(std::size_t num_workers, const std::string& name) {
-    workers_queued.store(static_cast<u64>(num_workers), std::memory_order_release);
-    const auto lambda = [this, thread_name{std::string{name}}] {
-        Common::SetCurrentThreadName(thread_name.c_str());
-
-        while (!stop) {
-            UniqueFunction<void> task;
-            {
-                std::unique_lock lock{queue_mutex};
-                if (requests.empty()) {
-                    wait_condition.notify_all();
-                }
-                condition.wait(lock, [this] { return stop || !requests.empty(); });
-                if (stop) {
-                    break;
-                }
-                task = std::move(requests.front());
-                requests.pop();
-            }
-            task();
-            work_done++;
-        }
-        workers_stopped++;
-        wait_condition.notify_all();
-    };
-    for (size_t i = 0; i < num_workers; ++i) {
-        threads.emplace_back(lambda);
-    }
-}
-
-ThreadWorker::~ThreadWorker() {
-    {
-        std::unique_lock lock{queue_mutex};
-        stop = true;
-    }
-    condition.notify_all();
-    for (std::thread& thread : threads) {
-        thread.join();
-    }
-}
-
-void ThreadWorker::QueueWork(UniqueFunction<void> work) {
-    {
-        std::unique_lock lock{queue_mutex};
-        requests.emplace(std::move(work));
-        work_scheduled++;
-    }
-    condition.notify_one();
-}
-
-void ThreadWorker::WaitForRequests() {
-    std::unique_lock lock{queue_mutex};
-    wait_condition.wait(
-        lock, [this] { return workers_stopped >= workers_queued || work_done >= work_scheduled; });
-}
-
-} // namespace Common
diff --git a/src/common/thread_worker.h b/src/common/thread_worker.h
index 12bbf5fef..16aa673bd 100644
--- a/src/common/thread_worker.h
+++ b/src/common/thread_worker.h
@@ -8,32 +8,107 @@
 #include <functional>
 #include <mutex>
 #include <string>
+#include <type_traits>
 #include <vector>
 #include <queue>
 
-#include "common/common_types.h"
+#include "common/thread.h"
 #include "common/unique_function.h"
 
 namespace Common {
 
-class ThreadWorker final {
+template <class StateType = void>
+class StatefulThreadWorker {
+    static constexpr bool with_state = !std::is_same_v<StateType, void>;
+
+    struct DummyCallable {
+        int operator()() const noexcept {
+            return 0;
+        }
+    };
+
+    using Task =
+        std::conditional_t<with_state, UniqueFunction<void, StateType*>, UniqueFunction<void>>;
+    using StateMaker = std::conditional_t<with_state, std::function<StateType()>, DummyCallable>;
+
 public:
-    explicit ThreadWorker(std::size_t num_workers, const std::string& name);
-    ~ThreadWorker();
-    void QueueWork(UniqueFunction<void> work);
-    void WaitForRequests();
+    explicit StatefulThreadWorker(size_t num_workers, std::string name, StateMaker func = {})
+        : workers_queued{num_workers}, thread_name{std::move(name)} {
+        const auto lambda = [this, func] {
+            Common::SetCurrentThreadName(thread_name.c_str());
+            {
+                std::conditional_t<with_state, StateType, int> state{func()};
+                while (!stop) {
+                    Task task;
+                    {
+                        std::unique_lock lock{queue_mutex};
+                        if (requests.empty()) {
+                            wait_condition.notify_all();
+                        }
+                        condition.wait(lock, [this] { return stop || !requests.empty(); });
+                        if (stop) {
+                            break;
+                        }
+                        task = std::move(requests.front());
+                        requests.pop();
+                    }
+                    if constexpr (with_state) {
+                        task(&state);
+                    } else {
+                        task();
+                    }
+                    ++work_done;
+                }
+            }
+            ++workers_stopped;
+            wait_condition.notify_all();
+        };
+        for (size_t i = 0; i < num_workers; ++i) {
+            threads.emplace_back(lambda);
+        }
+    }
+
+    ~StatefulThreadWorker() {
+        {
+            std::unique_lock lock{queue_mutex};
+            stop = true;
+        }
+        condition.notify_all();
+        for (std::thread& thread : threads) {
+            thread.join();
+        }
+    }
+
+    void QueueWork(Task work) {
+        {
+            std::unique_lock lock{queue_mutex};
+            requests.emplace(std::move(work));
+            ++work_scheduled;
+        }
+        condition.notify_one();
+    }
+
+    void WaitForRequests() {
+        std::unique_lock lock{queue_mutex};
+        wait_condition.wait(lock, [this] {
+            return workers_stopped >= workers_queued || work_done >= work_scheduled;
+        });
+    }
 
 private:
     std::vector<std::thread> threads;
-    std::queue<UniqueFunction<void>> requests;
+    std::queue<Task> requests;
     std::mutex queue_mutex;
     std::condition_variable condition;
     std::condition_variable wait_condition;
     std::atomic_bool stop{};
-    std::atomic<u64> work_scheduled{};
-    std::atomic<u64> work_done{};
-    std::atomic<u64> workers_stopped{};
-    std::atomic<u64> workers_queued{};
+    std::atomic<size_t> work_scheduled{};
+    std::atomic<size_t> work_done{};
+    std::atomic<size_t> workers_stopped{};
+    std::atomic<size_t> workers_queued{};
+    std::string thread_name;
 };
 
+using ThreadWorker = StatefulThreadWorker<>;
+
 } // namespace Common