-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_multiqueue_worker.py
48 lines (35 loc) · 1016 Bytes
/
test_multiqueue_worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import time
from torch.multiprocessing import Process, Queue
from multiqueue_worker import multiqueue_worker
def foo(delay: float = 5) -> str:
    """Simulate a slow job: sleep for *delay* seconds, then report the PID.

    Args:
        delay: Seconds to sleep before returning. Defaults to 5, the
            original hard-coded duration. The queue entries in this file
            pass an empty argument tuple, so they get this default.

    Returns:
        ``"foo (<pid>)"``, where ``<pid>`` is the ID of the process that
        executed the call. This makes it visible which worker ran the job.
    """
    time.sleep(delay)
    return f"foo ({os.getpid()})"
def bar(delay: float = 1) -> str:
    """Simulate a fast job: sleep for *delay* seconds, then report the PID.

    Args:
        delay: Seconds to sleep before returning. Defaults to 1, the
            original hard-coded duration. The queue entries in this file
            pass an empty argument tuple, so they get this default.

    Returns:
        ``"bar (<pid>)"``, where ``<pid>`` is the ID of the process that
        executed the call. This makes it visible which worker ran the job.
    """
    time.sleep(delay)
    return f"bar ({os.getpid()})"
def test_multiqueue_worker():
    """Smoke-test ``multiqueue_worker`` with two input queues and one output queue.

    Starts ``num_workers`` processes running ``multiqueue_worker``, with device
    indices assigned round-robin over ``num_devices``. It then enqueues slow
    ``foo`` jobs on the first input queue and fast ``bar`` jobs on the second,
    and prints one item from ``out_queue`` per submitted job. Finally it sends
    one ``None`` per worker on the first input queue and waits for every
    worker process to exit.

    Raises:
        queue.Empty: if no result arrives within ``timeout`` seconds, for
            example because a worker crashed.
        AssertionError: if any worker is still alive ``timeout`` seconds
            after the shutdown sentinels were sent.
    """
    # Torch settings forwarded to every worker; how they are applied is
    # decided by multiqueue_worker (not visible in this file).
    init_torch_kwargs = {
        "allow_tf32": False,
        "benchmark": False,
        "deterministic": True,
    }
    num_workers = 8
    num_devices = 2
    num_foo_jobs = 4
    num_bar_jobs = 12
    # Upper bound (seconds) for each result and for the whole shutdown, so a
    # dead or stuck worker makes the test fail instead of hanging forever.
    timeout = 120.0

    in_queues = [Queue(), Queue()]
    out_queue = Queue()

    workers = []
    for i in range(num_workers):
        device = i % num_devices
        args = (device, init_torch_kwargs, in_queues, out_queue)
        worker = Process(target=multiqueue_worker, args=args)
        worker.start()
        workers.append(worker)

    try:
        for _ in range(num_foo_jobs):
            in_queues[0].put((foo, ()))
        for _ in range(num_bar_jobs):
            in_queues[1].put((bar, ()))

        # One output item is expected per submitted job.
        for _ in range(num_foo_jobs + num_bar_jobs):
            print(out_queue.get(timeout=timeout))
    finally:
        # Shutdown signal: one None per worker, on the first queue only.
        # NOTE(review): confirm multiqueue_worker stops on a None from
        # in_queues[0] regardless of which queue it is currently serving.
        for _ in range(num_workers):
            in_queues[0].put(None)
        # Share one deadline across all joins so the total wait is bounded.
        deadline = time.monotonic() + timeout
        for worker in workers:
            worker.join(max(0.0, deadline - time.monotonic()))
        stuck = [worker for worker in workers if worker.is_alive()]
        for worker in stuck:
            worker.terminate()
            worker.join()
    assert not stuck, f"{len(stuck)} worker(s) did not exit after the None sentinels"
if __name__ == "__main__":
    # Script entry point: run the smoke test directly, outside pytest. The
    # guard also keeps child processes that re-import this module (e.g.
    # under the "spawn" start method) from re-running the test.
    test_multiqueue_worker()