Coverage for tests / unit / test_parallel.py: 96%

132 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1import sys 

2import pytest 

3import multiprocessing 

4import time 

5from typing import Any, List, Iterable 

6 

7# Import the function to test 

8from muutils.parallel import DEFAULT_PBAR_FN, run_maybe_parallel 

9 

# Named input datasets shared by the parametrized tests below.
DATA: dict = {
    "empty": [],
    "single": [5],
    "small": list(range(4)),
    "medium": list(range(10)),
    "large": list(range(50)),
}

# Expected outputs for each dataset, keyed by the same names as ``DATA``.
SQUARE_RESULTS: dict = {
    name: [value**2 for value in values] for name, values in DATA.items()
}
ADD_ONE_RESULTS: dict = {
    name: [value + 1 for value in values] for name, values in DATA.items()
}

19 

20 

21# Basic test functions 

def square(x: int) -> int:
    """Return ``x`` raised to the power of two."""
    return pow(x, 2)

24 

25 

def add_one(x: int) -> int:
    """Return ``x`` incremented by one."""
    incremented = x + 1
    return incremented

28 

29 

def raise_value_error(x: int) -> int:
    """Return ``x`` squared, except for ``x == 5``, which raises.

    Raises:
        ValueError: when ``x`` equals 5.
    """
    if x != 5:
        return x**2
    raise ValueError("Test error")

34 

35 

def slow_square(x: int) -> int:
    """Sleep for 0.0001 seconds, then return ``x`` squared."""
    time.sleep(0.0001)
    return pow(x, 2)

39 

40 

def raise_on_negative(x: int) -> int:
    """Return ``x`` unchanged unless it is below zero.

    Raises:
        ValueError: when ``x < 0``.
    """
    if not x < 0:
        return x
    raise ValueError("Negative number")

45 

46 

def stateful_fn(x: list) -> list:
    """Append ``1`` to ``x`` in place and return that same list object."""
    # The mutation of the caller's list is the point of this helper.
    x.append(1)
    return x

50 

51 

class ComplexObject:
    """Small value wrapper used as a non-primitive input and output in tests.

    Two instances compare equal when their ``value`` attributes are equal.
    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.
    """

    def __init__(self, value: int):
        self.value = value

    def __eq__(self, other: Any) -> bool:
        # Returning NotImplemented for foreign types lets Python try the
        # reflected comparison; `obj == other` still evaluates to False.
        if not isinstance(other, ComplexObject):
            return NotImplemented
        return self.value == other.value

    def __repr__(self) -> str:
        # Makes assertion failures readable instead of showing object ids.
        return f"{type(self).__name__}(value={self.value!r})"

58 

59 

def dataset_decorator(keys: List[str]):
    """Build a decorator that parametrizes a test over the named datasets.

    For every name in ``keys`` the decorated test receives
    ``input_range`` (from ``DATA``) and ``expected`` (from
    ``SQUARE_RESULTS``); the names themselves are used as the test ids.
    """

    def wrapper(test_func):
        cases = [(DATA[name], SQUARE_RESULTS[name]) for name in keys]
        parametrize = pytest.mark.parametrize(
            "input_range, expected",
            cases,
            ids=keys,
        )
        return parametrize(test_func)

    return wrapper

69 

70 

@dataset_decorator(["empty", "single", "small"])
@pytest.mark.parametrize("parallel", [False, True, 2, 4])
@pytest.mark.parametrize("keep_ordered", [True, False])
@pytest.mark.parametrize("use_multiprocess", [True, False])
def test_general_functionality(
    input_range, expected, parallel, keep_ordered, use_multiprocess
):
    """Run ``square`` through ``run_maybe_parallel`` across option combinations.

    Every combination must yield the expected values; when ``keep_ordered``
    is set, the order must match as well.
    """
    # Combinations where `use_multiprocess` is not applicable are reported as
    # skipped rather than silently passing.
    # NOTE(review): `parallel == 1` is also true for `parallel=True` because
    # bool is an int subclass, so that case is skipped too -- confirm intended.
    if use_multiprocess and (
        parallel is False or parallel == 1 or len(input_range) == 1
    ):
        pytest.skip("use_multiprocess is not applicable for this combination")

    # run the function
    results = run_maybe_parallel(
        func=square,
        iterable=input_range,
        parallel=parallel,
        pbar_kwargs={},
        keep_ordered=keep_ordered,
        use_multiprocess=use_multiprocess,
    )

    # Compare as sorted lists, not sets, so duplicated or dropped results are
    # caught even when order is not guaranteed.
    assert sorted(results) == sorted(expected)
    if keep_ordered:
        assert results == expected

98 

99 

@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "pbar_type",
    ["tqdm", "spinner", "none", None, "invalid"],
)
@pytest.mark.parametrize("disable_flag", [True, False])
def test_progress_bar_types_and_disable(input_range, expected, pbar_type, disable_flag):
    """Exercise each ``pbar`` selector with the ``disable`` kwarg on and off.

    The ``"invalid"`` selector is expected to raise ``ValueError`` only when
    the progress bar is not disabled; every other combination must return
    the expected squares.
    """
    pbar_kwargs = {"disable": disable_flag}
    expect_rejection = pbar_type == "invalid" and not disable_flag

    if not expect_rejection:
        results = run_maybe_parallel(
            square,
            input_range,
            False,
            pbar_kwargs,
            pbar=pbar_type,  # pyright: ignore[reportArgumentType]
        )
        assert results == expected
        return

    with pytest.raises(ValueError):
        run_maybe_parallel(
            square,
            input_range,
            False,
            pbar_kwargs,
            pbar=pbar_type,  # pyright: ignore[reportArgumentType]
        )

120 

121 

@dataset_decorator(["small"])
@pytest.mark.parametrize("chunksize", [None, 1, 5])
@pytest.mark.parametrize("parallel", [False, True, 2])
def test_chunksize_and_parallel(input_range, expected, chunksize, parallel):
    """Check that results match for every ``chunksize``/``parallel`` pairing."""
    outcome = run_maybe_parallel(
        square,
        input_range,
        parallel,
        {},
        chunksize=chunksize,
    )
    assert outcome == expected

128 

129 

@dataset_decorator(["small"])
@pytest.mark.parametrize("invalid_parallel", ["invalid", 0, -1, 1.5])
def test_invalid_parallel_values(input_range, expected, invalid_parallel):
    """Each unsupported ``parallel`` value must be rejected with ``ValueError``."""
    with pytest.raises(ValueError):
        run_maybe_parallel(
            square,
            input_range,
            invalid_parallel,
        )

135 

136 

def test_exception_in_func():
    """An exception raised by the mapped function must reach the caller."""
    # `raise_value_error` raises only for the value 5, so a one-element input
    # holding 5 is guaranteed to trigger the error.
    failing_input = [5]
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_value_error, failing_input, True, {})

143 

144 

@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "iterable_factory",
    [
        lambda x: list(x),
        lambda x: tuple(x),
        lambda x: set(x),
        lambda x: dict.fromkeys(x, 0),
    ],
)
def test_different_iterables(input_range, expected, iterable_factory):
    """Feed list, tuple, set and dict inputs through the serial path.

    Results for a set input are compared as sets; every other container is
    compared against the expected list directly.
    """
    container = iterable_factory(input_range)
    outcome = run_maybe_parallel(square, container, False)

    if not isinstance(container, set):
        assert outcome == expected
    else:
        assert set(outcome) == set(expected)

162 

163 

@pytest.mark.parametrize("parallel", [False, True])
def test_error_handling(parallel):
    """A ``ValueError`` from the mapped function must surface in both modes."""
    # Mixed-sign input: the negative entries make `raise_on_negative` fail.
    mixed_sign_values = [-1, 0, 1, -2]
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_on_negative, mixed_sign_values, parallel)

170 

171 

def _process_complex(obj):
    """Return a new ``ComplexObject`` whose value is twice ``obj.value``."""
    doubled = obj.value * 2
    return ComplexObject(doubled)

174 

175 

# Inputs with values 0..4 and the matching outputs with each value doubled.
COMPLEX_DATA: List[ComplexObject] = [ComplexObject(value) for value in range(5)]
EXPECTED_COMPLEX = [ComplexObject(value * 2) for value in range(5)]

178 

179 

@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("pbar_type", [None, DEFAULT_PBAR_FN])
def test_complex_objects(parallel, pbar_type):
    """Map ``_process_complex`` over ``COMPLEX_DATA`` and check every output.

    Runs serially and in parallel, with no progress bar and with the
    default one.
    """
    result = run_maybe_parallel(
        _process_complex, COMPLEX_DATA, parallel, pbar=pbar_type
    )
    # Compare whole lists: a zip()-based element check stops at the shorter
    # sequence and would pass for an empty or truncated result.
    assert list(result) == EXPECTED_COMPLEX

189 

190 

@dataset_decorator(["small"])
def test_resource_cleanup(input_range, expected):
    """Check the number of live child processes barely changes after a parallel run.

    The count reported by ``multiprocessing.active_children()`` may differ
    by at most two before and after the call.
    """
    children_before = len(multiprocessing.active_children())
    run_maybe_parallel(square, input_range, True)
    # Short pause before counting child processes again.
    time.sleep(0.05)
    children_after = len(multiprocessing.active_children())
    drift = abs(children_after - children_before)
    assert drift <= 2

198 

199 

@dataset_decorator(["small"])
def test_custom_progress_bar(input_range, expected):
    """A user-supplied callable can be passed as ``pbar``."""

    def custom_progress_bar_fn(iterable: Iterable, **kwargs: Any) -> Iterable:
        # Pass-through: ignores all keyword arguments and wraps nothing.
        return iterable

    outcome = run_maybe_parallel(
        square,
        input_range,
        False,
        pbar=custom_progress_bar_fn,
    )
    assert outcome == expected

207 

208 

@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "kwargs",
    [
        # basic
        None,
        {},
        {"desc": "Processing"},
        {"disable": True},
        {"ascii": True},
        {"config": "default"},
        {"config": "bar"},
        {"ascii": True, "config": "bar"},
        {"message": "Processing"},
        {"message": "Processing", "desc": "Processing"},
        # Progress bar selection
        {"pbar": "tqdm"},
        {"pbar": "spinner"},
        {"pbar": "none"},
        {"pbar": None},
        # Display control
        {"disable": True},
        {"disable": False},
        # Message variations
        {"desc": "Processing items"},
        {"message": "Processing items"},  # for spinner
        {"desc": "Processing items", "message": "Processing items"},  # tests precedence
        # Format customization
        {"ascii": True},  # for tqdm
        {"ascii": True, "config": "default"},  # combined spinner+tqdm settings
        # Spinner configs
        {"config": "default"},
        {"config": "dots"},
        {"config": "bar"},
        {"config": "braille"},
        {"config": ["1", "2", "3", "4"]},  # custom spinner sequence
        # Format strings
        {"format_string": "\r{spinner} {elapsed_time:.1f}s {message}"},
        {"format_string": "\r[{spinner}] {message} ({elapsed_time:.0f}s)"},
        # Combined settings
        {
            "desc": "Processing",
            "ascii": True,
            "config": "bar",
            "disable": False,
            "format_string": "\r{spinner} {message}",
        },
        # Update behavior
        {"format_string_when_updated": True},
        {"format_string_when_updated": "\r{spinner} UPDATED: {message}\n"},
        # Output stream control
        {"output_stream": sys.stdout},
        {"output_stream": sys.stderr},
        # Edge cases
        {"update_interval": 0.01},  # fast updates
        {"update_interval": 1.0},  # slow updates
        {"initial_value": "Starting..."},
        # Special format markers
        {"leave": True},  # tqdm specific
        {"leave": False},  # tqdm specific
        {"dynamic_ncols": True},  # tqdm specific
        # Multiple config overrides
        {
            "pbar": "spinner",
            "config": "dots",
            "format_string": "\r{spinner} {message}",
            "message": "Processing",
            "update_interval": 0.1,
            "format_string_when_updated": True,
        },
    ],
)
def test_progress_bar_kwargs(input_range, expected, kwargs):
    """Serial runs must return the expected squares for every ``pbar_kwargs`` case.

    Some cases appear more than once in the list above.
    """
    outcome = run_maybe_parallel(square, input_range, False, pbar_kwargs=kwargs)
    assert outcome == expected

284 

285 

@dataset_decorator(["medium"])
def test_parallel_performance(input_range, expected):
    """Serial and parallel runs of ``slow_square`` must both be correct.

    Checking each result against ``expected`` (not just against the other
    run) catches the case where both modes are wrong in the same way.
    """
    serial_result = run_maybe_parallel(slow_square, input_range, False)
    parallel_result = run_maybe_parallel(slow_square, input_range, True)
    assert serial_result == expected
    assert parallel_result == expected

291 

292 

@dataset_decorator(["small"])
def test_reject_pbar_str_when_not_str_or_callable(input_range, expected):
    """Passing an integer as ``pbar`` must raise ``TypeError``."""
    not_a_pbar = 12345
    with pytest.raises(TypeError):
        run_maybe_parallel(
            square,
            input_range,
            False,
            pbar=not_a_pbar,  # type: ignore[arg-type]
        )

297 

298 

def custom_pbar(iterable: Iterable, **kwargs: Any) -> List:
    """Consume ``iterable`` into a new list; keyword arguments are ignored."""
    return [*iterable]

301 

302 

@dataset_decorator(["small"])
def test_manual_callable_pbar(input_range, expected):
    """The module-level ``custom_pbar`` callable can be passed as ``pbar``."""
    outcome = run_maybe_parallel(
        square,
        input_range,
        False,
        pbar=custom_pbar,
    )
    assert outcome == expected, "Manual callable pbar test failed."

307 

308 

@pytest.mark.parametrize(
    "input_data, parallel",
    [
        (range(multiprocessing.cpu_count() + offset), True)
        for offset in (1, -1)
    ],
)
def test_edge_cases(input_data, parallel):
    """Parallel runs with one more and one fewer item than the CPU count."""
    outcome = run_maybe_parallel(square, input_data, parallel)
    reference = list(map(square, input_data))
    assert outcome == reference