Ticket #803: shared_array_test-1.py

File shared_array_test-1.py, 3.7 KB (added by tic20@…, 8 years ago)

Added by email2trac

Line 
1from numpy import ndarray, frombuffer
2import multiprocessing as mp
3from multiprocessing import sharedctypes
4import ctypes
5from time import sleep
6
def TypedMPArray(typecode_or_type, size_or_initializer, *, lock=True, ctx=None):
    '''
    Create a multiprocessing Array that also records its ctypes element type.

    typecode_or_type: an array-module typecode (e.g. 'd') or a ctypes type.
    size_or_initializer: the array length, or a sequence used to fill it.
    lock: forwarded to multiprocessing.sharedctypes.Array. True (default)
        creates a new lock, a lock object is used as given, and False
        returns an unsynchronised raw ctypes array.
    ctx: multiprocessing context used to create the lock (None = default).

    Returns the Array with an extra ``dtype`` attribute holding the ctypes
    element type, which SharedNumpyArray uses as the Numpy element type.
    '''
    # Previously lock/ctx were accepted but silently ignored.
    obj = sharedctypes.Array(typecode_or_type, size_or_initializer,
                             lock=lock, ctx=ctx)
    # Typecodes map to their ctypes type; a ctypes type passed in maps to
    # itself via the .get() default.
    obj.dtype = sharedctypes.typecode_to_type.get(typecode_or_type,
                                                  typecode_or_type)
    return obj
13
14
class SharedNumpyArray(ndarray):
    '''
    Numpy array view onto a multiprocessing Array that carries its lock.

    multiprocessing.Array objects synchronise element access through their
    lock by default, but are horribly inefficient at getting/setting data.
    If you want speed you need a Numpy array pointing at the same shared
    memory, but that circumvents the automatic acquire/release behaviour.
    Providing safe concurrent access would therefore require carrying
    around both the original Array (for the lock) and the derived Numpy
    array.

    This class gets the best of both worlds: it behaves just like a normal
    Numpy array, but also carries the Array's lock methods (get_lock,
    acquire, release). Arrays derived from it (reshape, slices, copies)
    inherit these attributes through __array_finalize__. Note that a copy
    inherits them too, even though it does not share the memory the lock
    protects. To use it:

    (In master process):
        import multiprocessing as mp
        mp_array = mp.Array(type, data_or_init)   # or TypedMPArray(...)
        shared_numpy = SharedNumpyArray(mp_array)

    Make shared_numpy available to the workers, e.g. through a Pool
    initializer, or pass it to each Process created on the fly.

    (In each worker):
        If workers never read and write the same index concurrently, just
        use it like any other array. Otherwise guard the access:

        with shared_numpy.get_lock():
            do_something(shared_numpy)

    NOTE(review): pickling goes through ndarray's own pickling, which copies
    the data and does not carry the attributes set here. Sharing therefore
    relies on workers inheriting the object (e.g. the 'fork' start method).
    Under 'spawn', pass the underlying Array instead and wrap it in the
    worker.
    '''
    def __new__(cls, mp_array):
        '''
        Wrap mp_array's shared buffer as a 1-D SharedNumpyArray.

        mp_array: a lock-carrying multiprocessing Array (from TypedMPArray,
            or mp.Array with its default lock). The element type comes from
            a ``dtype`` attribute when present (as set by TypedMPArray),
            otherwise from the underlying ctypes array type.

        Raises TypeError if mp_array is None or has no lock (e.g. a
        RawArray, or an Array created with lock=False).
        '''
        if mp_array is None or not hasattr(mp_array, 'get_lock'):
            # Implicit concatenation: the old backslash continuation leaked
            # source indentation into the message.
            raise TypeError('Please provide a TypedMPArray object '
                            'with a thread lock!')
        raw = mp_array.get_obj()
        # A plain mp.Array has no .dtype, so use the ctypes element type.
        dtype = getattr(mp_array, 'dtype', None)
        if dtype is None:
            dtype = type(raw)._type_
        # frombuffer shares the ctypes buffer rather than copying it.
        obj = frombuffer(raw, dtype).view(cls)
        obj._mparray = mp_array
        obj.get_lock = mp_array.get_lock
        obj.acquire = mp_array.acquire
        obj.release = mp_array.release
        return obj

    def __array_finalize__(self, obj):
        '''
        Propagate the Array reference and lock methods to derived arrays.

        obj is the array this one was derived from. It is None when the
        array is constructed directly, in which case there is nothing to
        copy.
        '''
        if obj is None:
            return

        # getattr defaults cover arrays derived from a plain ndarray.
        self._mparray = getattr(obj, '_mparray', None)
        self.get_lock = getattr(obj, 'get_lock', None)
        self.acquire = getattr(obj, 'acquire', None)
        self.release = getattr(obj, 'release', None)
62
def error_callback(e):
    '''Report an exception raised inside a pool worker by printing it.'''
    message = str(e)
    print(message)
65
66
def pool_init(arr):
    '''
    Pool initializer: store arr in this module's ``shared_array`` global
    so worker functions run later in the same process can reach it.
    '''
    module_namespace = globals()
    module_namespace['shared_array'] = arr
70
def thread_safe_add_one(n_iter=50, delay=1e-3):
    '''
    Worker: add 1 to every element of the module-level ``shared_array``,
    n_iter times, holding the array's lock across each read-modify-write.

    The sleep of delay seconds between the read and the write widens the
    race window. Because the lock is held across it, increments are not
    lost as long as every writer takes the same lock.

    Expects ``shared_array`` to have been set (e.g. by pool_init) to an
    object with get_lock(), copy() and slice assignment, such as a
    SharedNumpyArray.
    '''
    # shared_array is only read here, so no `global` declaration is needed.
    for _ in range(n_iter):
        with shared_array.get_lock():
            snapshot = shared_array.copy()
            sleep(delay)
            shared_array[:] = snapshot + 1
78
def thread_unsafe_add_one(n_iter=50, delay=1e-3):
    '''
    Worker: add 1 to every element of the module-level ``shared_array``,
    n_iter times, WITHOUT taking any lock.

    Each round copies the array, sleeps delay seconds and writes back
    copy + 1. Concurrent workers can therefore overwrite each other's
    increments (lost updates). This is the deliberately unsafe variant,
    used for comparison with the locked one.

    Expects ``shared_array`` to have been set (e.g. by pool_init) to an
    array-like object supporting copy() and slice assignment.
    '''
    # shared_array is only read here, so no `global` declaration is needed.
    for _ in range(n_iter):
        snapshot = shared_array.copy()
        sleep(delay)
        shared_array[:] = snapshot + 1
85
def _init_worker(mp_array, shape):
    '''
    Pool initializer: wrap the shared Array in this worker and publish it.

    Sending the multiprocessing Array itself (instead of the numpy view)
    keeps the shared buffer and its lock intact under every start method.
    Pickling the ndarray subclass would hand each worker a private copy.
    '''
    if not hasattr(mp_array, 'dtype'):
        # The dtype attribute added by TypedMPArray is not carried across
        # pickling (spawn/forkserver), so recover it from the ctypes type.
        mp_array.dtype = type(mp_array.get_obj())._type_
    pool_init(SharedNumpyArray(mp_array).reshape(shape))


def _run_pool(mp_array, shape, worker, n_workers=3):
    '''Run worker once in each of n_workers pool processes and wait.'''
    with mp.Pool(processes=n_workers, initializer=_init_worker,
                 initargs=(mp_array, shape)) as pool:
        for _ in range(n_workers):
            pool.apply_async(worker,
                             args=(),
                             error_callback=error_callback)
        pool.close()
        pool.join()


def main():
    '''Demonstrate locked vs. unlocked read-modify-write on a shared array.'''
    shape = (2, 5)
    mp_arr = TypedMPArray(ctypes.c_long, 10)
    print(mp_arr.dtype)
    s_arr = SharedNumpyArray(mp_arr).reshape(shape)

    _run_pool(mp_arr, shape, thread_safe_add_one)
    print('Values should all equal 150')
    print(s_arr)

    s_arr[:] = 0

    _run_pool(mp_arr, shape, thread_unsafe_add_one)
    print('Values should NOT all equal 150')
    print(s_arr)


# Required for the spawn/forkserver start methods: worker processes
# re-import this module and must not re-run the demo.
if __name__ == '__main__':
    main()
112
113