1 /**
2 	Multi-threaded task pool implementation.
3 
4 	Copyright: © 2012-2020 Sönke Ludwig
5 	License: Subject to the terms of the MIT license, as written in the included LICENSE.txt file.
6 	Authors: Sönke Ludwig
7 */
8 module vibe.core.taskpool;
9 
10 import vibe.core.concurrency : isWeaklyIsolated;
11 import vibe.core.core : exitEventLoop, logicalProcessorCount, runEventLoop, runTask, runTask_internal;
12 import vibe.core.log;
13 import vibe.core.sync : ManualEvent, Monitor, createSharedManualEvent, createMonitor;
14 import vibe.core.task : Task, TaskFuncInfo, TaskSettings, callWithMove;
15 import core.sync.mutex : Mutex;
16 import core.thread : Thread;
17 import std.concurrency : prioritySend, receiveOnly;
18 import std.traits : isFunctionPointer;
19 
20 
/** Implements a shared, multi-threaded task pool.

	The pool starts a fixed number of `WorkerThread`s on construction. Tasks
	are either placed into a queue shared by all workers (`runTask`,
	`runTaskH`) or into the private queue of every worker (`runTaskDist`,
	`runTaskDistH`). Workers are woken through a shared manual event.
*/
shared final class TaskPool {
	private {
		// State that is only accessed while holding the monitor's mutex.
		struct State {
			// worker threads that are still registered with the pool
			WorkerThread[] threads;
			// tasks that may be executed by any worker thread
			TaskQueue queue;
			// set by `terminate()` to request that the workers exit
			bool term;
		}
		vibe.core.sync.Monitor!(State, shared(Mutex)) m_state;
		// emitted when work has been queued or termination was requested
		shared(ManualEvent) m_signal;
		// thread count requested at construction; never changes afterwards
		immutable size_t m_threadCount;
	}

	/** Creates a new task pool with the specified number of threads.

		All worker threads are created, named `vibe-<index>` and started
		before the constructor returns.

		Params:
			thread_count = The number of worker threads to create
	*/
	this(size_t thread_count = logicalProcessorCount())
	@safe {
		import std.format : format;

		m_threadCount = thread_count;
		m_signal = createSharedManualEvent();
		m_state = createMonitor!State(new shared Mutex);

		with (m_state.lock) {
			queue.setup();
			threads.length = thread_count;
			foreach (i; 0 .. thread_count) {
				WorkerThread thr;
				() @trusted {
					thr = new WorkerThread(this);
					thr.name = format("vibe-%s", i);
					thr.start();
				} ();
				threads[i] = thr;
			}
		}
	}

	/** Returns the number of worker threads.

		This is the count that was passed to the constructor; it is not
		updated when worker threads exit.
	*/
	@property size_t threadCount() const shared { return m_threadCount; }

	/** Instructs all worker threads to terminate and waits until all have
		finished.

		Sets the termination flag, wakes all workers and then joins every
		registered worker thread except the calling thread itself. A warning
		is logged if joining a thread fails or if tasks are still left in
		the shared queue afterwards.
	*/
	void terminate()
	@safe nothrow {
		m_state.lock.term = true;
		m_signal.emit();

		while (true) {
			// take one thread at a time out of the list so that the lock is
			// not held while waiting for the thread to exit
			WorkerThread th;
			with (m_state.lock)
				if (threads.length) {
					th = threads[0];
					threads = threads[1 .. $];
				}
			if (!th) break;

			// a worker thread cannot join itself
			if (th is Thread.getThis())
				continue;

			() @trusted {
				try th.join();
				catch (Exception e) {
					logWarn("Failed to wait for worker thread exit: %s", e.msg);
				}
			} ();
		}

		size_t cnt = m_state.lock.queue.length;
		if (cnt > 0) logWarn("There were still %d worker tasks pending at exit.", cnt);
	}

	/** Instructs all worker threads to terminate as soon as all tasks have
		been processed and waits for them to finish.

		Note: This is not implemented yet - calling it always triggers an
		assertion failure.
	*/
	void join()
	@safe nothrow {
		assert(false, "TODO!");
	}

	/** Runs a new asynchronous task in a worker thread.

		Only function pointers with weakly isolated arguments are allowed to be
		able to guarantee thread-safety.

		Params:
			settings = Optional task settings (defaults to `TaskSettings.init`)
			func = Function pointer to execute in the worker thread
			object = Shared object on which `method` is invoked
			args = Arguments forwarded to the function/method
	*/
	void runTask(FT, ARGS...)(FT func, auto ref ARGS args)
		if (isFunctionPointer!FT)
	{
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");
		runTask_unsafe(TaskSettings.init, func, args);
	}
	/// ditto
	void runTask(alias method, T, ARGS...)(shared(T) object, auto ref ARGS args)
		if (is(typeof(__traits(getMember, object, __traits(identifier, method)))))
	{
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");
		auto func = &__traits(getMember, object, __traits(identifier, method));
		runTask_unsafe(TaskSettings.init, func, args);
	}
	/// ditto
	void runTask(FT, ARGS...)(TaskSettings settings, FT func, auto ref ARGS args)
		if (isFunctionPointer!FT)
	{
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");
		runTask_unsafe(settings, func, args);
	}
	/// ditto
	void runTask(alias method, T, ARGS...)(TaskSettings settings, shared(T) object, auto ref ARGS args)
		if (is(typeof(__traits(getMember, object, __traits(identifier, method)))))
	{
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");
		auto func = &__traits(getMember, object, __traits(identifier, method));
		runTask_unsafe(settings, func, args);
	}

	/** Runs a new asynchronous task in a worker thread, returning the task handle.

		This function will yield and wait for the new task to be created and started
		in the worker thread, then resume and return it.

		When called from outside of a task, a temporary local task is started
		to perform the call and is joined before returning.

		Only function pointers with weakly isolated arguments are allowed to be
		able to guarantee thread-safety.

		Returns:
			The handle of the task that was started in the worker thread.
	*/
	Task runTaskH(FT, ARGS...)(FT func, auto ref ARGS args)
		if (isFunctionPointer!FT)
	{
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");

		// workaround for runWorkerTaskH to work when called outside of a task
		if (Task.getThis() == Task.init) {
			Task ret;
			.runTask({ ret = doRunTaskH(TaskSettings.init, func, args); }).join();
			return ret;
		} else return doRunTaskH(TaskSettings.init, func, args);
	}
	/// ditto
	Task runTaskH(alias method, T, ARGS...)(shared(T) object, auto ref ARGS args)
		if (is(typeof(__traits(getMember, object, __traits(identifier, method)))))
	{
		// adapts the method call to a plain function pointer
		static void wrapper()(shared(T) object, ref ARGS args) {
			__traits(getMember, object, __traits(identifier, method))(args);
		}
		return runTaskH(&wrapper!(), object, args);
	}
	/// ditto
	Task runTaskH(FT, ARGS...)(TaskSettings settings, FT func, auto ref ARGS args)
		if (isFunctionPointer!FT)
	{
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");

		// workaround for runWorkerTaskH to work when called outside of a task
		if (Task.getThis() == Task.init) {
			Task ret;
			.runTask({ ret = doRunTaskH(settings, func, args); }).join();
			return ret;
		} else return doRunTaskH(settings, func, args);
	}
	/// ditto
	Task runTaskH(alias method, T, ARGS...)(TaskSettings settings, shared(T) object, auto ref ARGS args)
		if (is(typeof(__traits(getMember, object, __traits(identifier, method)))))
	{
		// adapts the method call to a plain function pointer
		static void wrapper()(shared(T) object, ref ARGS args) {
			__traits(getMember, object, __traits(identifier, method))(args);
		}
		return runTaskH(settings, &wrapper!(), object, args);
	}

	// Queues `func` for execution in a worker thread and blocks the calling
	// task until the worker task has sent back its own handle.
	//
	// NOTE: needs to be a separate function to avoid recursion for the
	//       workaround above, which breaks @safe inference
	private Task doRunTaskH(FT, ARGS...)(TaskSettings settings, FT func, ref ARGS args)
		if (isFunctionPointer!FT)
	{
		import std.typecons : Typedef;

		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");

		// distinct message type, named after this function instantiation, that
		// is used to receive only the handle message sent by `taskFun` below
		alias PrivateTask = Typedef!(Task, Task.init, __PRETTY_FUNCTION__);
		Task caller = Task.getThis();

		assert(caller != Task.init, "runWorkderTaskH can currently only be called from within a task.");
		static void taskFun(Task caller, FT func, ARGS args) {
			// report the handle of the new task back before running the payload
			PrivateTask callee = Task.getThis();
			caller.tid.prioritySend(callee);
			mixin(callWithMove!ARGS("func", "args"));
		}
		runTask_unsafe(settings, &taskFun, caller, func, args);
		return cast(Task)() @trusted { return receiveOnly!PrivateTask(); } ();
	}


	/** Runs a new asynchronous task in all worker threads concurrently.

		This function is mainly useful for long-living tasks that distribute their
		work across all CPU cores. Only function pointers with weakly isolated
		arguments are allowed to be able to guarantee thread-safety.

		The number of tasks started is guaranteed to be equal to
		`threadCount`.

		Params:
			settings = Optional task settings (defaults to `TaskSettings.init`)
			func = Function pointer to execute in every worker thread
			object = Shared object on which `method` is invoked
			args = Arguments that are copied for every worker thread
	*/
	void runTaskDist(FT, ARGS...)(FT func, auto ref ARGS args)
		if (is(typeof(*func) == function))
	{
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");
		runTaskDist_unsafe(TaskSettings.init, func, args);
	}
	/// ditto
	void runTaskDist(alias method, T, ARGS...)(shared(T) object, auto ref ARGS args)
	{
		auto func = &__traits(getMember, object, __traits(identifier, method));
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");

		runTaskDist_unsafe(TaskSettings.init, func, args);
	}
	/// ditto
	void runTaskDist(FT, ARGS...)(TaskSettings settings, FT func, auto ref ARGS args)
		if (is(typeof(*func) == function))
	{
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");
		runTaskDist_unsafe(settings, func, args);
	}
	/// ditto
	void runTaskDist(alias method, T, ARGS...)(TaskSettings settings, shared(T) object, auto ref ARGS args)
	{
		auto func = &__traits(getMember, object, __traits(identifier, method));
		foreach (T; ARGS) static assert(isWeaklyIsolated!T, "Argument type "~T.stringof~" is not safe to pass between threads.");

		runTaskDist_unsafe(settings, func, args);
	}

	/** Runs a new asynchronous task in all worker threads and returns the handles.

		`on_handle` is an alias to a callable that takes a `Task` as its only
		argument and is called for every task instance that gets created.

		When called from outside of a task, a temporary local task is started
		to perform the call and is joined before returning.

		Params:
			settings = Optional task settings (defaults to `TaskSettings.init`)
			on_handle = Callback invoked once per created task handle
			func = Function to execute in every worker thread
			args = Arguments that are copied for every worker thread

		See_also: `runTaskDist`
	*/
	void runTaskDistH(HCB, FT, ARGS...)(scope HCB on_handle, FT func, auto ref ARGS args)
		if (!is(HCB == TaskSettings))
	{
		runTaskDistH(TaskSettings.init, on_handle, func, args);
	}
	/// ditto
	void runTaskDistH(HCB, FT, ARGS...)(TaskSettings settings, scope HCB on_handle, FT func, auto ref ARGS args)
	{
		// TODO: support non-copyable argument types using .move
		import std.concurrency : send, receiveOnly;

		auto caller = Task.getThis();

		// workaround to work when called outside of a task
		if (caller == Task.init) {
			// forward `settings`, so that the requested task settings are
			// also applied when taking this detour through a local task
			.runTask({ runTaskDistH(settings, on_handle, func, args); }).join();
			return;
		}

		// sends the handle of the worker task back to the caller before
		// running the actual function
		static void call(Task t, FT func, ARGS args) {
			t.tid.send(Task.getThis());
			func(args);
		}
		runTaskDist(settings, &call, caller, func, args);

		// one handle message is expected per worker thread
		foreach (i; 0 .. this.threadCount)
			on_handle(receiveOnly!Task);
	}

	// Appends a task to the shared queue and wakes a single worker thread.
	// Performs no isolation checks on the argument types - callers are
	// responsible for those.
	private void runTask_unsafe(CALLABLE, ARGS...)(TaskSettings settings, CALLABLE callable, ref ARGS args)
	{
		import std.traits : ParameterTypeTuple;
		import vibe.internal.traits : areConvertibleTo;
		import vibe.internal.typetuple;

		alias FARGS = ParameterTypeTuple!CALLABLE;
		static assert(areConvertibleTo!(Group!ARGS, Group!FARGS),
			"Cannot convert arguments '"~ARGS.stringof~"' to function arguments '"~FARGS.stringof~"'.");

		m_state.lock.queue.put(settings, callable, args);
		m_signal.emitSingle();
	}

	// Appends a task to the private queue of every registered worker thread
	// and wakes all of them. Performs no isolation checks on the argument
	// types - callers are responsible for those.
	private void runTaskDist_unsafe(CALLABLE, ARGS...)(TaskSettings settings, ref CALLABLE callable, ARGS args) // NOTE: no ref for args, to disallow non-copyable types!
	{
		import std.traits : ParameterTypeTuple;
		import vibe.internal.traits : areConvertibleTo;
		import vibe.internal.typetuple;

		alias FARGS = ParameterTypeTuple!CALLABLE;
		static assert(areConvertibleTo!(Group!ARGS, Group!FARGS),
			"Cannot convert arguments '"~ARGS.stringof~"' to function arguments '"~FARGS.stringof~"'.");

		{
			// the lock is released again before the workers get signaled
			auto st = m_state.lock;
			foreach (thr; st.threads) {
				// create one TFI per thread to properly account for elaborate assignment operators/postblit
				thr.m_queue.put(settings, callable, args);
			}
		}
		m_signal.emit();
	}
}
326 
/** Worker thread of a `TaskPool`.

	Executes tasks taken from its own private queue as well as from the
	shared queue of the owning pool, until the pool's termination flag is
	set.
*/
private final class WorkerThread : Thread {
	private {
		// pool that owns this thread
		shared(TaskPool) m_pool;
		// tasks that were queued specifically for this thread
		TaskQueue m_queue;
	}

	/// Prepares the private queue and registers `main` as the thread function.
	this(shared(TaskPool) pool)
	{
		m_pool = pool;
		m_queue.setup();
		super(&main);
	}

	/** Thread entry point.

		Returns immediately if the pool has already been flagged for
		termination. Anything thrown while processing tasks is logged and
		the process is aborted.
	*/
	private void main()
	nothrow {
		import core.stdc.stdlib : abort;
		import core.exception : InvalidMemoryOperationError;
		import std.encoding : sanitize;

		try {
			immutable bool terminated = m_pool.m_state.lock.term;
			if (terminated) return;

			logDebug("entering worker thread");
			handleWorkerTasks();
			logDebug("Worker thread exit.");
		} catch (Throwable err) {
			logFatal("Worker thread terminated due to uncaught error: %s", err.msg);
			logDebug("Full error: %s", err.toString().sanitize());
			abort();
		}
	}

	/** Processing loop of the worker thread.

		Picks the next task - private queue first, then the pool's shared
		queue - and runs it. Waits on the pool's signal whenever no task is
		available. Once the termination flag is observed, the thread removes
		itself from the pool's thread list.
	*/
	private void handleWorkerTasks()
	nothrow @safe {
		import std.algorithm.iteration : filter;
		import std.algorithm.mutation : swap;
		import std.algorithm.searching : count;
		import std.array : array;

		logTrace("worker thread enter");

		TaskFuncInfo pending;
		auto last_emit = m_pool.m_signal.emitCount;

		for (;;) {
			{
				// the pool lock is held only while picking the next task
				auto st = m_pool.m_state.lock;
				logTrace("worker thread check");

				if (st.term) break;

				if (m_queue.consume(pending))
					logTrace("worker thread got specific task");
				else if (st.queue.consume(pending))
					logTrace("worker thread got unspecific task");
			}

			if (pending.func is null) {
				// nothing to do - sleep until the pool signals again
				last_emit = m_pool.m_signal.waitUninterruptible(last_emit);
				continue;
			}

			// hand the task over by swapping, which leaves `pending` holding
			// whatever `tfi` contained before
			.runTask_internal!((ref tfi) { swap(tfi, pending); });
		}

		logTrace("worker thread exit");

		if (!m_queue.empty)
			logWarn("Worker thread shuts down with specific worker tasks left in its queue.");

		{
			// unregister this thread from the pool
			auto st = m_pool.m_state.lock;
			st.threads = st.threads.filter!(t => t !is this).array;

			immutable bool others_left = st.threads.length > 0;
			if (others_left && !st.queue.empty)
				logWarn("Worker threads shut down with worker tasks still left in the queue.");
		}
	}
}
398 
/** FIFO queue of pending task descriptors, backed by a GC allocated ring
	buffer that grows on demand.

	`setup` must be called before any other member is used.
*/
private struct TaskQueue {
nothrow @safe:
	// TODO: avoid use of GC

	import vibe.internal.array : FixedRingBuffer;
	FixedRingBuffer!TaskFuncInfo* m_queue;

	/// Allocates the underlying ring buffer.
	void setup()
	{
		m_queue = new FixedRingBuffer!TaskFuncInfo;
	}

	/// Determines whether no tasks are currently queued.
	@property bool empty()
	const {
		return m_queue.empty;
	}

	/// Number of tasks currently queued.
	@property size_t length()
	const {
		return m_queue.length;
	}

	/** Appends a new task to the end of the queue.

		A full buffer is grown to 1.5 times its capacity, but to at least 16
		entries, before the task is stored.
	*/
	void put(CALLABLE, ARGS...)(TaskSettings settings, ref CALLABLE c, ref ARGS args)
	{
		import std.algorithm.comparison : max;

		if (m_queue.full) {
			auto grown = m_queue.capacity * 3 / 2;
			m_queue.capacity = max(16, grown);
		}
		assert(!m_queue.full);

		// fill the next free slot in place, then commit it
		auto dst = m_queue.peekDst;
		dst[0].settings = settings;
		dst[0].set(c, args);
		m_queue.putN(1);
	}

	/** Removes the oldest task from the queue.

		The task is swapped into `tfi`.

		Returns:
			`true` if a task was removed, `false` if the queue was empty - in
			which case `tfi` is left untouched.
	*/
	bool consume(ref TaskFuncInfo tfi)
	{
		import std.algorithm.mutation : swap;

		if (!m_queue.empty) {
			swap(tfi, m_queue.front);
			m_queue.popFront();
			return true;
		}
		return false;
	}
}