Coverage for tests/unit/test_parallel.py: 96%

132 statements  

coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

import sys
import pytest
import multiprocessing
import time
from typing import Any, List, Iterable

# Import the function to test
from muutils.parallel import DEFAULT_PBAR_FN, run_maybe_parallel

DATA: dict = dict(
    empty=[],
    single=[5],
    small=list(range(4)),
    medium=list(range(10)),
    large=list(range(50)),
)
SQUARE_RESULTS: dict = {k: [x**2 for x in v] for k, v in DATA.items()}
ADD_ONE_RESULTS: dict = {k: [x + 1 for x in v] for k, v in DATA.items()}


# Basic test functions
def square(x: int) -> int:
    return x**2


def add_one(x: int) -> int:
    return x + 1


def raise_value_error(x: int) -> int:
    if x == 5:
        raise ValueError("Test error")
    return x**2


def slow_square(x: int) -> int:
    time.sleep(0.0001)
    return x**2


def raise_on_negative(x: int) -> int:
    if x < 0:
        raise ValueError("Negative number")
    return x


def stateful_fn(x: list) -> list:
    x.append(1)
    return x


class ComplexObject:
    def __init__(self, value: int):
        self.value = value

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, ComplexObject) and self.value == other.value


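# helper: parametrize a test over (input, expected-squares) pairs for the named dataset keys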
def dataset_decorator(keys: List[str]):
    def wrapper(test_func):
        return pytest.mark.parametrize(
            "input_range, expected",
            [(DATA[k], SQUARE_RESULTS[k]) for k in keys],
            ids=keys,
        )(test_func)

    return wrapper


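# core test matrix: serial vs. parallel execution, result ordering, and the multiprocess backend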
@dataset_decorator(["empty", "single", "small"])
@pytest.mark.parametrize("parallel", [False, True, 2, 4])
@pytest.mark.parametrize("keep_ordered", [True, False])
@pytest.mark.parametrize("use_multiprocess", [True, False])
def test_general_functionality(
    input_range, expected, parallel, keep_ordered, use_multiprocess
):
    # skip combinations where multiprocess can't be used (serial runs or single-item inputs)
    if use_multiprocess and (
        parallel is False or parallel == 1 or len(input_range) == 1
    ):
        return

    # run the function
    results = run_maybe_parallel(
        func=square,
        iterable=input_range,
        parallel=parallel,
        pbar_kwargs={},
        keep_ordered=keep_ordered,
        use_multiprocess=use_multiprocess,
    )

    # check the results
    assert set(results) == set(expected)
    if keep_ordered:
        assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "pbar_type",
    ["tqdm", "spinner", "none", None, "invalid"],
)
@pytest.mark.parametrize("disable_flag", [True, False])
def test_progress_bar_types_and_disable(input_range, expected, pbar_type, disable_flag):
    pbar_kwargs = {"disable": disable_flag}
    if pbar_type == "invalid" and not disable_flag:
        with pytest.raises(ValueError):
            run_maybe_parallel(square, input_range, False, pbar_kwargs, pbar=pbar_type)
    else:
        results = run_maybe_parallel(
            square, input_range, False, pbar_kwargs, pbar=pbar_type
        )
        assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize("chunksize", [None, 1, 5])
@pytest.mark.parametrize("parallel", [False, True, 2])
def test_chunksize_and_parallel(input_range, expected, chunksize, parallel):
    results = run_maybe_parallel(square, input_range, parallel, {}, chunksize=chunksize)
    assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize("invalid_parallel", ["invalid", 0, -1, 1.5])
def test_invalid_parallel_values(input_range, expected, invalid_parallel):
    with pytest.raises(ValueError):
        run_maybe_parallel(square, input_range, invalid_parallel)


def test_exception_in_func():
    # the standard datasets (0..3) never raise, so inject an input
    # that raise_value_error is known to fail on
    error_input = [5]  # raises ValueError
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_value_error, error_input, True, {})


@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "iterable_factory",
    [
        lambda x: list(x),
        lambda x: tuple(x),
        lambda x: set(x),
        lambda x: dict.fromkeys(x, 0),
    ],
)
def test_different_iterables(input_range, expected, iterable_factory):
    test_input = iterable_factory(input_range)
    result = run_maybe_parallel(square, test_input, False)
    if isinstance(test_input, set):
        assert set(result) == set(expected)
    else:
        assert result == expected


@pytest.mark.parametrize("parallel", [False, True])
def test_error_handling(parallel):
    # use inputs containing negatives so that raise_on_negative raises
    input_data = [-1, 0, 1, -2]
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_on_negative, input_data, parallel)


def _process_complex(obj):
    return ComplexObject(obj.value * 2)


COMPLEX_DATA: List[ComplexObject] = [ComplexObject(i) for i in range(5)]
EXPECTED_COMPLEX = [ComplexObject(i * 2) for i in range(5)]


@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("pbar_type", [None, DEFAULT_PBAR_FN])
def test_complex_objects(parallel, pbar_type):
    # use ComplexObject instances as inputs for this test
    result = run_maybe_parallel(
        _process_complex, COMPLEX_DATA, parallel, pbar=pbar_type
    )
    expected_complex = EXPECTED_COMPLEX
    assert all(a == b for a, b in zip(result, expected_complex))


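# parallel runs should not leave stray worker processes behind (small tolerance for slow pool teardown)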
@dataset_decorator(["small"])
def test_resource_cleanup(input_range, expected):
    initial_processes = len(multiprocessing.active_children())
    run_maybe_parallel(square, input_range, True)
    time.sleep(0.05)
    final_processes = len(multiprocessing.active_children())
    assert abs(final_processes - initial_processes) <= 2


@dataset_decorator(["small"])
def test_custom_progress_bar(input_range, expected):
    def custom_progress_bar_fn(iterable: Iterable, **kwargs: Any) -> Iterable:
        return iterable

    result = run_maybe_parallel(square, input_range, False, pbar=custom_progress_bar_fn)
    assert result == expected


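# a grab-bag of tqdm- and spinner-style pbar_kwargs; results must be unaffected by any of them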
@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "kwargs",
    [
        # basic
        None,
        dict(),
        dict(desc="Processing"),
        dict(disable=True),
        dict(ascii=True),
        dict(config="default"),
        dict(config="bar"),
        dict(ascii=True, config="bar"),
        dict(message="Processing"),
        dict(message="Processing", desc="Processing"),
        # Progress bar selection
        dict(pbar="tqdm"),
        dict(pbar="spinner"),
        dict(pbar="none"),
        dict(pbar=None),
        # Display control
        dict(disable=True),
        dict(disable=False),
        # Message variations
        dict(desc="Processing items"),
        dict(message="Processing items"),  # for spinner
        dict(desc="Processing items", message="Processing items"),  # tests precedence
        # Format customization
        dict(ascii=True),  # for tqdm
        dict(ascii=True, config="default"),  # combined spinner+tqdm settings
        # Spinner configs
        dict(config="default"),
        dict(config="dots"),
        dict(config="bar"),
        dict(config="braille"),
        dict(config=["1", "2", "3", "4"]),  # custom spinner sequence
        # Format strings
        dict(format_string="\r{spinner} {elapsed_time:.1f}s {message}"),
        dict(format_string="\r[{spinner}] {message} ({elapsed_time:.0f}s)"),
        # Combined settings
        dict(
            desc="Processing",
            ascii=True,
            config="bar",
            disable=False,
            format_string="\r{spinner} {message}",
        ),
        # Update behavior
        dict(format_string_when_updated=True),
        dict(format_string_when_updated="\r{spinner} UPDATED: {message}\n"),
        # Output stream control
        dict(output_stream=sys.stdout),
        dict(output_stream=sys.stderr),
        # Edge cases
        dict(update_interval=0.01),  # fast updates
        dict(update_interval=1.0),  # slow updates
        dict(initial_value="Starting..."),
        # Special format markers
        dict(leave=True),  # tqdm specific
        dict(leave=False),  # tqdm specific
        dict(dynamic_ncols=True),  # tqdm specific
        # Multiple config overrides
        dict(
            pbar="spinner",
            config="dots",
            format_string="\r{spinner} {message}",
            message="Processing",
            update_interval=0.1,
            format_string_when_updated=True,
        ),
    ],
)
def test_progress_bar_kwargs(input_range, expected, kwargs):
    result = run_maybe_parallel(square, input_range, False, pbar_kwargs=kwargs)
    assert result == expected


@dataset_decorator(["medium"])
def test_parallel_performance(input_range, expected):
    serial_result = run_maybe_parallel(slow_square, input_range, False)
    parallel_result = run_maybe_parallel(slow_square, input_range, True)
    assert serial_result == parallel_result


@dataset_decorator(["small"])
def test_reject_pbar_str_when_not_str_or_callable(input_range, expected):
    with pytest.raises(TypeError):
        run_maybe_parallel(square, input_range, False, pbar=12345)  # type: ignore[arg-type]


def custom_pbar(iterable: Iterable, **kwargs: Any) -> List:
    return list(iterable)


@dataset_decorator(["small"])
def test_manual_callable_pbar(input_range, expected):
    results = run_maybe_parallel(square, input_range, False, pbar=custom_pbar)
    assert results == expected, "Manual callable pbar test failed."


@pytest.mark.parametrize(
    "input_data, parallel",
    [
        (range(multiprocessing.cpu_count() + 1), True),
        (range(multiprocessing.cpu_count() - 1), True),
    ],
)
def test_edge_cases(input_data, parallel):
    result = run_maybe_parallel(square, input_data, parallel)
    assert result == [square(x) for x in input_data]