Coverage for tests/unit/test_parallel.py: 96%
132 statements
coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
import sys
import pytest
import multiprocessing
import time
from typing import Any, List, Iterable

# Import the function under test and the default progress bar callable
from muutils.parallel import DEFAULT_PBAR_FN, run_maybe_parallel
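
# Note: the keyword arguments exercised throughout this module suggest a signature
# roughly like the sketch below. The parameter names come from the calls in these
# tests; the defaults shown are assumptions, not taken from muutils.parallel itself:
#
#   run_maybe_parallel(
#       func,                    # callable applied to each element
#       iterable,                # input elements
#       parallel,                # False for serial, True or an int worker count
#       pbar_kwargs=None,        # kwargs forwarded to the progress bar
#       keep_ordered=True,       # whether results preserve input order
#       use_multiprocess=False,  # presumably selects the third-party `multiprocess` backend
#       pbar="tqdm",             # "tqdm", "spinner", "none", None, or a custom callable
#       chunksize=None,          # chunk size handed to the pool map
#   )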

DATA: dict = dict(
    empty=[],
    single=[5],
    small=list(range(4)),
    medium=list(range(10)),
    large=list(range(50)),
)
SQUARE_RESULTS: dict = {k: [x**2 for x in v] for k, v in DATA.items()}
ADD_ONE_RESULTS: dict = {k: [x + 1 for x in v] for k, v in DATA.items()}


# Basic test functions
def square(x: int) -> int:
    return x**2


def add_one(x: int) -> int:
    return x + 1


def raise_value_error(x: int) -> int:
    if x == 5:
        raise ValueError("Test error")
    return x**2


def slow_square(x: int) -> int:
    time.sleep(0.0001)
    return x**2


def raise_on_negative(x: int) -> int:
    if x < 0:
        raise ValueError("Negative number")
    return x


def stateful_fn(x: list) -> list:
    x.append(1)
    return x


class ComplexObject:
    def __init__(self, value: int):
        self.value = value

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, ComplexObject) and self.value == other.value
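
# Note: ComplexObject defines value-based __eq__ so the parallel tests can compare
# results by value; worker processes return unpickled copies, not the original objects.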


def dataset_decorator(keys: List[str]):
    def wrapper(test_func):
        return pytest.mark.parametrize(
            "input_range, expected",
            [(DATA[k], SQUARE_RESULTS[k]) for k in keys],
            ids=keys,
        )(test_func)

    return wrapper
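
# Usage sketch: `@dataset_decorator(["small"])` is equivalent to
#   @pytest.mark.parametrize(
#       "input_range, expected",
#       [(DATA["small"], SQUARE_RESULTS["small"])],
#       ids=["small"],
#   )
# so each decorated test receives one (inputs, expected squares) pair per listed key.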


@dataset_decorator(["empty", "single", "small"])
@pytest.mark.parametrize("parallel", [False, True, 2, 4])
@pytest.mark.parametrize("keep_ordered", [True, False])
@pytest.mark.parametrize("use_multiprocess", [True, False])
def test_general_functionality(
    input_range, expected, parallel, keep_ordered, use_multiprocess
):
    # skip combinations where multiprocess cannot be used:
    # serial runs (parallel False or 1) and single-element inputs
    if use_multiprocess and (
        parallel is False or parallel == 1 or len(input_range) == 1
    ):
        return

    # run the function
    results = run_maybe_parallel(
        func=square,
        iterable=input_range,
        parallel=parallel,
        pbar_kwargs={},
        keep_ordered=keep_ordered,
        use_multiprocess=use_multiprocess,
    )

    # check the results
    assert set(results) == set(expected)
    if keep_ordered:
        assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "pbar_type",
    ["tqdm", "spinner", "none", None, "invalid"],
)
@pytest.mark.parametrize("disable_flag", [True, False])
def test_progress_bar_types_and_disable(input_range, expected, pbar_type, disable_flag):
    pbar_kwargs = {"disable": disable_flag}
    if pbar_type == "invalid" and not disable_flag:
        # an unknown pbar name should be rejected unless the bar is disabled entirely
        with pytest.raises(ValueError):
            run_maybe_parallel(square, input_range, False, pbar_kwargs, pbar=pbar_type)  # pyright: ignore[reportArgumentType]
    else:
        results = run_maybe_parallel(
            square,
            input_range,
            False,
            pbar_kwargs,
            pbar=pbar_type,  # pyright: ignore[reportArgumentType]
        )
        assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize("chunksize", [None, 1, 5])
@pytest.mark.parametrize("parallel", [False, True, 2])
def test_chunksize_and_parallel(input_range, expected, chunksize, parallel):
    results = run_maybe_parallel(square, input_range, parallel, {}, chunksize=chunksize)
    assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize("invalid_parallel", ["invalid", 0, -1, 1.5])
def test_invalid_parallel_values(input_range, expected, invalid_parallel):
    with pytest.raises(ValueError):
        run_maybe_parallel(square, input_range, invalid_parallel)


def test_exception_in_func():
    # inject an input that is known to raise, and make sure the error propagates
    # out of the parallel run
    error_input = [5]  # raise_value_error raises ValueError on 5
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_value_error, error_input, True, {})


@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "iterable_factory",
    [
        lambda x: list(x),
        lambda x: tuple(x),
        lambda x: set(x),
        lambda x: dict.fromkeys(x, 0),
    ],
)
def test_different_iterables(input_range, expected, iterable_factory):
    test_input = iterable_factory(input_range)
    result = run_maybe_parallel(square, test_input, False)
    if isinstance(test_input, set):
        # sets do not preserve order, so compare as sets
        assert set(result) == set(expected)
    else:
        assert result == expected


@pytest.mark.parametrize("parallel", [False, True])
def test_error_handling(parallel):
    # include negative values so that raise_on_negative raises
    input_data = [-1, 0, 1, -2]
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_on_negative, input_data, parallel)


def _process_complex(obj):
    return ComplexObject(obj.value * 2)


COMPLEX_DATA: List[ComplexObject] = [ComplexObject(i) for i in range(5)]
EXPECTED_COMPLEX = [ComplexObject(i * 2) for i in range(5)]


@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("pbar_type", [None, DEFAULT_PBAR_FN])
def test_complex_objects(parallel, pbar_type):
    # custom-class inputs must round-trip through the pool unchanged in value
    result = run_maybe_parallel(
        _process_complex, COMPLEX_DATA, parallel, pbar=pbar_type
    )
    assert result == EXPECTED_COMPLEX


@dataset_decorator(["small"])
def test_resource_cleanup(input_range, expected):
    # parallel runs should not leak worker processes
    initial_processes = len(multiprocessing.active_children())
    run_maybe_parallel(square, input_range, True)
    time.sleep(0.05)
    final_processes = len(multiprocessing.active_children())
    assert abs(final_processes - initial_processes) <= 2


@dataset_decorator(["small"])
def test_custom_progress_bar(input_range, expected):
    def custom_progress_bar_fn(iterable: Iterable, **kwargs: Any) -> Iterable:
        return iterable

    result = run_maybe_parallel(square, input_range, False, pbar=custom_progress_bar_fn)
    assert result == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "kwargs",
    [
        # basic
        None,
        dict(),
        dict(desc="Processing"),
        dict(disable=True),
        dict(ascii=True),
        dict(config="default"),
        dict(config="bar"),
        dict(ascii=True, config="bar"),
        dict(message="Processing"),
        dict(message="Processing", desc="Processing"),
        # Progress bar selection
        dict(pbar="tqdm"),
        dict(pbar="spinner"),
        dict(pbar="none"),
        dict(pbar=None),
        # Display control
        dict(disable=True),
        dict(disable=False),
        # Message variations
        dict(desc="Processing items"),
        dict(message="Processing items"),  # for spinner
        dict(desc="Processing items", message="Processing items"),  # tests precedence
        # Format customization
        dict(ascii=True),  # for tqdm
        dict(ascii=True, config="default"),  # combined spinner+tqdm settings
        # Spinner configs
        dict(config="default"),
        dict(config="dots"),
        dict(config="bar"),
        dict(config="braille"),
        dict(config=["1", "2", "3", "4"]),  # custom spinner sequence
        # Format strings
        dict(format_string="\r{spinner} {elapsed_time:.1f}s {message}"),
        dict(format_string="\r[{spinner}] {message} ({elapsed_time:.0f}s)"),
        # Combined settings
        dict(
            desc="Processing",
            ascii=True,
            config="bar",
            disable=False,
            format_string="\r{spinner} {message}",
        ),
        # Update behavior
        dict(format_string_when_updated=True),
        dict(format_string_when_updated="\r{spinner} UPDATED: {message}\n"),
        # Output stream control
        dict(output_stream=sys.stdout),
        dict(output_stream=sys.stderr),
        # Edge cases
        dict(update_interval=0.01),  # fast updates
        dict(update_interval=1.0),  # slow updates
        dict(initial_value="Starting..."),
        # Special format markers
        dict(leave=True),  # tqdm specific
        dict(leave=False),  # tqdm specific
        dict(dynamic_ncols=True),  # tqdm specific
        # Multiple config overrides
        dict(
            pbar="spinner",
            config="dots",
            format_string="\r{spinner} {message}",
            message="Processing",
            update_interval=0.1,
            format_string_when_updated=True,
        ),
    ],
)
def test_progress_bar_kwargs(input_range, expected, kwargs):
    result = run_maybe_parallel(square, input_range, False, pbar_kwargs=kwargs)
    assert result == expected


@dataset_decorator(["medium"])
def test_parallel_performance(input_range, expected):
    # not a timing benchmark: only check that serial and parallel runs agree
    serial_result = run_maybe_parallel(slow_square, input_range, False)
    parallel_result = run_maybe_parallel(slow_square, input_range, True)
    assert serial_result == parallel_result


@dataset_decorator(["small"])
def test_reject_pbar_str_when_not_str_or_callable(input_range, expected):
    with pytest.raises(TypeError):
        run_maybe_parallel(square, input_range, False, pbar=12345)  # type: ignore[arg-type]


def custom_pbar(iterable: Iterable, **kwargs: Any) -> List:
    return list(iterable)


@dataset_decorator(["small"])
def test_manual_callable_pbar(input_range, expected):
    results = run_maybe_parallel(square, input_range, False, pbar=custom_pbar)
    assert results == expected, "Manual callable pbar test failed."


@pytest.mark.parametrize(
    "input_data, parallel",
    [
        (range(multiprocessing.cpu_count() + 1), True),
        (range(multiprocessing.cpu_count() - 1), True),
    ],
)
def test_edge_cases(input_data, parallel):
    result = run_maybe_parallel(square, input_data, parallel)
    assert result == [square(x) for x in input_data]
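

# Minimal usage sketch (illustrative only; the leading underscore keeps pytest from
# collecting it). It mirrors the simplest configuration exercised above and assumes
# only parameters already used in this module; `_example_usage` itself is an addition.
def _example_usage() -> None:
    values = list(range(8))
    squares = run_maybe_parallel(
        func=square,
        iterable=values,
        parallel=2,  # two worker processes
        pbar_kwargs={"disable": True},  # silence the progress bar
    )
    assert squares == [x**2 for x in values]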