Ticket #803: shared_array_test.py

File shared_array_test.py, 3.3 KB (added by Tristan Croll, 8 years ago)

Process-safe shared NumPy arrays (backed by a locked multiprocessing.Array)

Line 
1import numpy
2from numpy import ndarray, frombuffer
3import multiprocessing as mp
4from time import sleep
5
class SharedNumpyArray(ndarray):
    """
    NumPy array viewing the shared memory of a ``multiprocessing.Array``,
    carrying that Array's lock along with it.

    ``multiprocessing.Array`` objects are process-safe by default (element
    access goes through their lock), but getting/setting data through them
    is very slow. For speed you want a NumPy array pointing at the same
    shared memory, but that bypasses the automatic acquire/release. Keeping
    the safety would therefore require carrying around both the original
    Array (for the lock) and the derived NumPy array. This class combines
    the two: it behaves like a normal NumPy array, but also exposes the
    Array's ``get_lock``, ``acquire`` and ``release``. Arrays derived from it
    (views, reshapes, copies, arithmetic results) inherit those attributes.
    To use it:

    (In the master process):
        import multiprocessing as mp
        mp_array = mp.Array(typecode, data_or_size)
        shared_numpy = SharedNumpyArray(mp_array)

    Pass ``shared_numpy`` to the Pool ``initializer`` (via ``initargs``), or
    to each worker process when creating processes on the fly.

    (In each worker):
        If synchronisation is not required (that is, different workers never
        read and write the same index), just use it like any other array. If
        it *is* required:

            with shared_numpy.get_lock():
                do_something(shared_numpy)

    Caveat: ``copy()`` and arithmetic results are still SharedNumpyArray
    instances carrying the lock, but their data is private, not shared.

    NOTE(review): sharing appears to rely on workers inheriting the parent's
    memory (the 'fork' start method). Under 'spawn' the array is pickled,
    which copies the data and drops the lock attributes - confirm before
    relying on this with spawn.
    """

    def __new__(cls, mp_array):
        """
        Wrap *mp_array*, a synchronized ``multiprocessing.Array``, without
        copying its data.

        Raises:
            TypeError: if *mp_array* is None.
        """
        if mp_array is None:
            raise TypeError('Please provide a multiprocessing.Array object '
                            'with a thread lock!')
        raw = mp_array.get_obj()
        # Derive the dtype from the ctypes element type (e.g. c_int, c_double)
        # instead of the Python type of the first element. For typecode 'i'
        # (4-byte C int), mp_array[0] is a Python int, which NumPy maps to a
        # 64-bit integer on typical 64-bit Linux/macOS builds and would
        # misread the buffer. This also works for zero-length arrays, where
        # mp_array[0] would raise IndexError.
        obj = frombuffer(raw, dtype=raw._type_).view(cls)
        obj._mparray = mp_array
        obj.get_lock = mp_array.get_lock
        obj.acquire = mp_array.acquire
        obj.release = mp_array.release
        return obj

    def __array_finalize__(self, obj):
        """
        Propagate the source Array and its lock methods to derived arrays.

        NumPy calls this hook for views, reshapes, copies and ufunc outputs.
        *obj* is None only when the array is created through the explicit
        ``ndarray.__new__`` constructor, in which case nothing is copied.
        """
        if obj is None:
            return

        self._mparray = getattr(obj, '_mparray', None)
        self.get_lock = getattr(obj, 'get_lock', None)
        self.acquire = getattr(obj, 'acquire', None)
        self.release = getattr(obj, 'release', None)
53
def error_callback(e):
    """Report an exception raised by a pool task.

    Intended as the ``error_callback`` of ``Pool.apply_async``. The exception
    is written to standard output in its ``str()`` form.
    """
    text = str(e)
    print(text)
56
57
def pool_init(arr):
    """Pool ``initializer``: publish *arr* as the module global ``shared_array``.

    The task functions read ``shared_array`` from module scope, so each
    worker process calls this once before running any task.
    """
    globals()['shared_array'] = arr
61
def thread_safe_add_one():
    """Add one to every element of the global ``shared_array``, 50 times.

    Each read-copy-sleep-write round runs while holding the array's lock, so
    it cannot interleave with other workers that take the same lock. The
    short sleep between read and write widens the race window that the
    unlocked variant exposes.
    """
    for _ in range(50):
        lock = shared_array.get_lock()
        with lock:
            snapshot = shared_array.copy()
            sleep(1e-3)
            shared_array[:] = snapshot + 1
69
def thread_unsafe_add_one():
    """Add one to every element of the global ``shared_array``, 50 times, unlocked.

    Deliberately takes no lock. When several workers run this concurrently,
    one worker's write can overwrite another's during the sleep between the
    read and the write, losing increments.
    """
    for _ in range(50):
        snapshot = shared_array.copy()
        sleep(1e-3)
        shared_array[:] = snapshot + 1
76
def _run_in_pool(shared, task, workers=3):
    """Run *task* once in each of *workers* pool processes and wait for all.

    Every worker is initialised with ``pool_init(shared)`` so *task* can reach
    the array through the ``shared_array`` global. Task exceptions are passed
    to ``error_callback`` rather than being discarded with the async result.
    """
    with mp.Pool(processes=workers, initializer=pool_init,
                 initargs=(shared,)) as pool:
        for _ in range(workers):
            pool.apply_async(task, args=(), error_callback=error_callback)
        # close() + join() let every queued task finish. Leaving the `with`
        # block alone would call terminate() and could cut tasks short.
        pool.close()
        pool.join()


if __name__ == '__main__':
    # Only run the demo when executed as a script. Under the 'spawn' start
    # method (default on Windows and macOS), each worker re-imports this
    # module, and unguarded top-level code would try to start new pools
    # while the worker is still bootstrapping.
    # NOTE(review): with 'spawn' the array in initargs is pickled (a copy),
    # so the printed values will not reflect worker updates. The demo as
    # written appears to assume 'fork' - confirm on the target platform.
    mp_arr = mp.Array('d', 10)
    s_arr = SharedNumpyArray(mp_arr).reshape((2, 5))

    # 3 workers x 50 locked increments each: every element should reach 150.
    _run_in_pool(s_arr, thread_safe_add_one)
    print('Values should all equal 150')
    print(s_arr)

    s_arr[:] = 0

    # Same workload without the lock: lost updates can leave values below 150.
    _run_in_pool(s_arr, thread_unsafe_add_one)
    print('Values should NOT all equal 150')
    print(s_arr)
101
102