Coverage for tests/unit/test_parallel.py: 96% (132 statements)
import sys
import pytest
import multiprocessing
import time
from typing import Any, List, Iterable

# Import the function to test
from muutils.parallel import DEFAULT_PBAR_FN, run_maybe_parallel
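
# NOTE: the call signature assumed throughout this file -- func, iterable,
# parallel, and pbar_kwargs (passed positionally or by name), plus the keyword
# arguments pbar, keep_ordered, use_multiprocess, and chunksize -- is inferred
# from how run_maybe_parallel is invoked below, not restated from its docs.


# Named test datasets and their precomputed expected results.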
DATA: dict = dict(
    empty=[],
    single=[5],
    small=list(range(4)),
    medium=list(range(10)),
    large=list(range(50)),
)
SQUARE_RESULTS: dict = {k: [x**2 for x in v] for k, v in DATA.items()}
ADD_ONE_RESULTS: dict = {k: [x + 1 for x in v] for k, v in DATA.items()}


# Basic test functions
def square(x: int) -> int:
    return x**2


def add_one(x: int) -> int:
    return x + 1


def raise_value_error(x: int) -> int:
    if x == 5:
        raise ValueError("Test error")
    return x**2


def slow_square(x: int) -> int:
    time.sleep(0.0001)
    return x**2


def raise_on_negative(x: int) -> int:
    if x < 0:
        raise ValueError("Negative number")
    return x


def stateful_fn(x: list) -> list:
    x.append(1)
    return x
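

# Simple value wrapper with equality, used to check that non-primitive,
# picklable objects survive a round trip through run_maybe_parallel.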
class ComplexObject:
    def __init__(self, value: int):
        self.value = value

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, ComplexObject) and self.value == other.value
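

# Parametrizes a test over the named datasets, pairing each input with its
# precomputed squared expectation from SQUARE_RESULTS.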
def dataset_decorator(keys: List[str]):
    def wrapper(test_func):
        return pytest.mark.parametrize(
            "input_range, expected",
            [(DATA[k], SQUARE_RESULTS[k]) for k in keys],
            ids=keys,
        )(test_func)

    return wrapper
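

# Runs square over each dataset for every combination of parallel,
# keep_ordered, and use_multiprocess, checking that results match
# (and stay ordered when keep_ordered is set).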
@dataset_decorator(["empty", "single", "small"])
@pytest.mark.parametrize("parallel", [False, True, 2, 4])
@pytest.mark.parametrize("keep_ordered", [True, False])
@pytest.mark.parametrize("use_multiprocess", [True, False])
def test_general_functionality(
    input_range, expected, parallel, keep_ordered, use_multiprocess
):
    # multiprocess only applies when actually running in parallel with more
    # than one item, so skip the other combinations
    if use_multiprocess and (
        parallel is False or parallel == 1 or len(input_range) == 1
    ):
        return

    # run the function
    results = run_maybe_parallel(
        func=square,
        iterable=input_range,
        parallel=parallel,
        pbar_kwargs={},
        keep_ordered=keep_ordered,
        use_multiprocess=use_multiprocess,
    )

    # check the results
    assert set(results) == set(expected)
    if keep_ordered:
        assert results == expected
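

# Each progress bar backend should accept the disable flag; an invalid pbar
# name is expected to raise ValueError unless the bar is disabled entirely.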
@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "pbar_type",
    ["tqdm", "spinner", "none", None, "invalid"],
)
@pytest.mark.parametrize("disable_flag", [True, False])
def test_progress_bar_types_and_disable(input_range, expected, pbar_type, disable_flag):
    pbar_kwargs = {"disable": disable_flag}
    if pbar_type == "invalid" and not disable_flag:
        with pytest.raises(ValueError):
            run_maybe_parallel(square, input_range, False, pbar_kwargs, pbar=pbar_type)
    else:
        results = run_maybe_parallel(
            square, input_range, False, pbar_kwargs, pbar=pbar_type
        )
        assert results == expected
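

# chunksize should not affect results, whether running serially or in parallel.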
@dataset_decorator(["small"])
@pytest.mark.parametrize("chunksize", [None, 1, 5])
@pytest.mark.parametrize("parallel", [False, True, 2])
def test_chunksize_and_parallel(input_range, expected, chunksize, parallel):
    results = run_maybe_parallel(square, input_range, parallel, {}, chunksize=chunksize)
    assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize("invalid_parallel", ["invalid", 0, -1, 1.5])
def test_invalid_parallel_values(input_range, expected, invalid_parallel):
    with pytest.raises(ValueError):
        run_maybe_parallel(square, input_range, invalid_parallel)


def test_exception_in_func():
    # raise_value_error only raises on input 5, so feed it exactly that
    # and check the exception propagates out of the parallel run
    error_input = [5]  # will raise ValueError
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_value_error, error_input, True, {})
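

# Lists, tuples, sets, and dicts (iterated over their keys) should all be
# accepted; sets lose ordering, so only membership is compared for them.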
@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "iterable_factory",
    [
        lambda x: list(x),
        lambda x: tuple(x),
        lambda x: set(x),
        lambda x: dict.fromkeys(x, 0),
    ],
)
def test_different_iterables(input_range, expected, iterable_factory):
    test_input = iterable_factory(input_range)
    result = run_maybe_parallel(square, test_input, False)
    if isinstance(test_input, set):
        assert set(result) == set(expected)
    else:
        assert result == expected


@pytest.mark.parametrize("parallel", [False, True])
def test_error_handling(parallel):
    # inputs containing negatives should make raise_on_negative raise,
    # both serially and in parallel
    input_data = [-1, 0, 1, -2]
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_on_negative, input_data, parallel)
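

# _process_complex is defined at module level so it can be pickled and sent
# to worker processes when the parallel path uses multiprocessing.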
def _process_complex(obj):
    return ComplexObject(obj.value * 2)


COMPLEX_DATA: List[ComplexObject] = [ComplexObject(i) for i in range(5)]
EXPECTED_COMPLEX: List[ComplexObject] = [ComplexObject(i * 2) for i in range(5)]


@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("pbar_type", [None, DEFAULT_PBAR_FN])
def test_complex_objects(parallel, pbar_type):
    # use the module-level complex objects instead of the numeric datasets
    result = run_maybe_parallel(
        _process_complex, COMPLEX_DATA, parallel, pbar=pbar_type
    )
    assert len(result) == len(EXPECTED_COMPLEX)
    assert all(a == b for a, b in zip(result, EXPECTED_COMPLEX))
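

# The process count is compared with a tolerance of 2 because pool workers
# may take a moment to shut down after the parallel run returns.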
@dataset_decorator(["small"])
def test_resource_cleanup(input_range, expected):
    initial_processes = len(multiprocessing.active_children())
    run_maybe_parallel(square, input_range, True)
    time.sleep(0.05)
    final_processes = len(multiprocessing.active_children())
    assert abs(final_processes - initial_processes) <= 2


@dataset_decorator(["small"])
def test_custom_progress_bar(input_range, expected):
    def custom_progress_bar_fn(iterable: Iterable, **kwargs: Any) -> Iterable:
        return iterable

    result = run_maybe_parallel(square, input_range, False, pbar=custom_progress_bar_fn)
    assert result == expected
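

# A grab bag of pbar_kwargs mixing tqdm-style options (desc, ascii, leave,
# dynamic_ncols) with spinner-style ones (message, config, format_string,
# update_interval, output_stream); the test only asserts that results are
# unaffected, on the assumption that each backend tolerates or filters keys
# meant for the other.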
@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "kwargs",
    [
        # basic
        None,
        dict(),
        dict(desc="Processing"),
        dict(disable=True),
        dict(ascii=True),
        dict(config="default"),
        dict(config="bar"),
        dict(ascii=True, config="bar"),
        dict(message="Processing"),
        dict(message="Processing", desc="Processing"),
        # Progress bar selection
        dict(pbar="tqdm"),
        dict(pbar="spinner"),
        dict(pbar="none"),
        dict(pbar=None),
        # Display control
        dict(disable=True),
        dict(disable=False),
        # Message variations
        dict(desc="Processing items"),
        dict(message="Processing items"),  # for spinner
        dict(desc="Processing items", message="Processing items"),  # tests precedence
        # Format customization
        dict(ascii=True),  # for tqdm
        dict(ascii=True, config="default"),  # combined spinner+tqdm settings
        # Spinner configs
        dict(config="default"),
        dict(config="dots"),
        dict(config="bar"),
        dict(config="braille"),
        dict(config=["1", "2", "3", "4"]),  # custom spinner sequence
        # Format strings
        dict(format_string="\r{spinner} {elapsed_time:.1f}s {message}"),
        dict(format_string="\r[{spinner}] {message} ({elapsed_time:.0f}s)"),
        # Combined settings
        dict(
            desc="Processing",
            ascii=True,
            config="bar",
            disable=False,
            format_string="\r{spinner} {message}",
        ),
        # Update behavior
        dict(format_string_when_updated=True),
        dict(format_string_when_updated="\r{spinner} UPDATED: {message}\n"),
        # Output stream control
        dict(output_stream=sys.stdout),
        dict(output_stream=sys.stderr),
        # Edge cases
        dict(update_interval=0.01),  # fast updates
        dict(update_interval=1.0),  # slow updates
        dict(initial_value="Starting..."),
        # Special format markers
        dict(leave=True),  # tqdm specific
        dict(leave=False),  # tqdm specific
        dict(dynamic_ncols=True),  # tqdm specific
        # Multiple config overrides
        dict(
            pbar="spinner",
            config="dots",
            format_string="\r{spinner} {message}",
            message="Processing",
            update_interval=0.1,
            format_string_when_updated=True,
        ),
    ],
)
def test_progress_bar_kwargs(input_range, expected, kwargs):
    result = run_maybe_parallel(square, input_range, False, pbar_kwargs=kwargs)
    assert result == expected
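

# Despite the name, this only checks that serial and parallel runs agree;
# no timing is asserted.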
@dataset_decorator(["medium"])
def test_parallel_performance(input_range, expected):
    serial_result = run_maybe_parallel(slow_square, input_range, False)
    parallel_result = run_maybe_parallel(slow_square, input_range, True)
    assert serial_result == parallel_result


@dataset_decorator(["small"])
def test_reject_pbar_str_when_not_str_or_callable(input_range, expected):
    with pytest.raises(TypeError):
        run_maybe_parallel(square, input_range, False, pbar=12345)  # type: ignore[arg-type]


def custom_pbar(iterable: Iterable, **kwargs: Any) -> List:
    return list(iterable)


@dataset_decorator(["small"])
def test_manual_callable_pbar(input_range, expected):
    results = run_maybe_parallel(square, input_range, False, pbar=custom_pbar)
    assert results == expected, "Manual callable pbar test failed."
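

# Input sizes one above and one below cpu_count() exercise the boundary
# between the number of items and the number of available workers.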
@pytest.mark.parametrize(
    "input_data, parallel",
    [
        (range(multiprocessing.cpu_count() + 1), True),
        (range(multiprocessing.cpu_count() - 1), True),
    ],
)
def test_edge_cases(input_data, parallel):
    result = run_maybe_parallel(square, input_data, parallel)
    assert result == [square(x) for x in input_data]