This is a follow up to my previous post. I've made a number of improvements to the thread pool and corrected some bugs as well.
The most up to date version of the code is available on my Github.
I have since removed the use of `std::binary_semaphore` and instead moved to using `std::condition_variable_any`. I experimented with using a `std::counting_semaphore` instead, but couldn't figure out a good way to do so.
thread_pool.h
#pragma once

#include <atomic>
#include <concepts>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stop_token>
#include <thread>
#include <type_traits>
#include <vector>

#include "thread_pool/thread_safe_queue.h"
namespace dp {
namespace detail {
template <class T>
std::decay_t<T> decay_copy(T &&v) {
return std::forward<T>(v);
}
// bind F and parameter pack into a nullary one shot. Lambda captures by value.
template <typename... Args, typename F>
auto bind(F &&f, Args &&...args) {
return [f = decay_copy(std::forward<F>(f)),
... args = decay_copy(std::forward<Args>(args))]() mutable -> decltype(auto) {
return std::invoke(std::move(f), std::move(args)...);
};
}
} // namespace detail
template <typename FunctionType = std::function<void()>>
requires std::invocable<FunctionType> &&
std::is_same_v<void, std::invoke_result_t<FunctionType>>
class thread_pool {
public:
thread_pool(const unsigned int &number_of_threads = std::thread::hardware_concurrency()) {
for (std::size_t i = 0; i < number_of_threads; ++i) {
threads_.emplace_back([&](const std::stop_token stop_tok) {
do {
// check if we have task
if (queue_.empty()) {
// no tasks, so we wait instead of spinning
std::unique_lock lock(condition_mutex_);
condition_.wait(lock, stop_tok, [this]() { return !queue_.empty(); });
}
// ensure we have a task before getting task
// since the dtor notifies via the condition variable as well
if (!queue_.empty()) {
// get the task
auto task = queue_.pop();
// invoke the task
std::invoke(std::move(task));
// decrement in-flight counter
--in_flight_;
}
} while (!stop_tok.stop_requested());
});
}
}
~thread_pool() {
// wait for tasks to complete first
do {
std::this_thread::yield();
} while (in_flight_ > 0);
// stop all threads
for (auto &thread : threads_) {
thread.request_stop();
}
condition_.notify_all();
}
/// thread pool is non-copyable
thread_pool(const thread_pool &) = delete;
thread_pool &operator=(const thread_pool &) = delete;
template <typename Function, typename... Args,
typename ReturnType = std::invoke_result_t<Function &&, Args &&...>>
requires std::invocable<Function, Args...>
[[nodiscard]] std::future<ReturnType> enqueue(Function f, Args... args) {
/*
* use shared promise here so that we don't break the promise later (until C++23)
*
* with C++23 we can do the following:
*
* std::promise<ReturnType> promise;
* auto future = promise.get_future();
* auto task = [func = std::move(f), ...largs = std::move(args),
promise = std::move(promise)]() mutable {...};
*/
auto shared_promise = std::make_shared<std::promise<ReturnType>>();
auto task = [func = std::move(f), ... largs = std::move(args),
promise = shared_promise]() { promise->set_value(func(largs...)); };
// get the future before enqueuing the task
auto future = shared_promise->get_future();
// enqueue the task
enqueue_task(std::move(task));
return future;
}
template <typename Function, typename... Args>
requires std::invocable<Function, Args...> &&
std::is_same_v<void, std::invoke_result_t<Function &&, Args &&...>>
void enqueue_detach(Function &&func, Args &&...args) {
enqueue_task(detail::bind(std::forward<Function>(func), std::forward<Args>(args)...));
}
private:
template <typename Function>
void enqueue_task(Function &&f) {
++in_flight_;
{
std::lock_guard lock(condition_mutex_);
queue_.push(std::forward<Function>(f));
}
condition_.notify_all();
}
std::condition_variable_any condition_;
std::mutex condition_mutex_;
std::vector<std::jthread> threads_;
dp::thread_safe_queue<FunctionType> queue_;
std::atomic<int64_t> in_flight_{0};
};
} // namespace dp
Again for clarity, below is my thread safe queue implementation:
thread_safe_queue.h
#pragma once
#include <condition_variable>
#include <deque>
#include <mutex>
namespace dp {
template <typename T>
class thread_safe_queue {
public:
using value_type = T;
using size_type = typename std::deque<T>::size_type;
thread_safe_queue() = default;
void push(T&& value) {
{
std::lock_guard lock(mutex_);
data_.push_back(std::forward<T>(value));
}
condition_variable_.notify_all();
}
bool empty() {
std::lock_guard lock(mutex_);
return data_.empty();
}
[[nodiscard]] size_type size() {
std::lock_guard lock(mutex_);
return data_.size();
}
[[nodiscard]] T pop() {
std::unique_lock lock(mutex_);
condition_variable_.wait(lock, [this] { return !data_.empty(); });
auto front = data_.front();
data_.pop_front();
return front;
}
private:
using mutex_type = std::mutex;
std::deque<T> data_;
mutable mutex_type mutex_{};
std::condition_variable condition_variable_{};
};
} // namespace dp
Example driver code:
#include <thread_pool/thread_pool.h>
dp::thread_pool pool(4);
const auto total_tasks = 30;
std::vector<std::future<int>> futures;
for (auto i = 0; i < total_tasks; i++) {
auto task = [index = i]() { return index; };
futures.push_back(pool.enqueue(task));
}
Any and all feedback is much appreciated. Would you use this implementation in one of your projects? If not, please share! I'm curious to hear where this can be improved. My goal is to have something that is not only performant, but reliable and "bulletproof".