Coverage for muutils/parallel.py: 94%
93 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-07 20:16 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-07 20:16 -0700
1"parallel processing utilities, chiefly `run_maybe_parallel`"
3from __future__ import annotations
5import multiprocessing
6import functools
7from typing import (
8 Any,
9 Callable,
10 Iterable,
11 Literal,
12 Optional,
13 Tuple,
14 TypeVar,
15 Dict,
16 List,
17 Union,
18 Protocol,
19)
21# for no tqdm fallback
22from muutils.spinner import SpinnerContext
23from muutils.validate_type import get_fn_allowed_kwargs
# type variables shared by `run_maybe_parallel`:
# `InputType` is the element type of the input iterable,
# `OutputType` is what the mapped function returns for each element
InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
class ProgressBarFunction(Protocol):
    """structural type for a progress bar callable

    anything that can be called with an iterable plus arbitrary keyword options
    and gives back an iterable satisfies this protocol (e.g. `tqdm.tqdm`).
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        """wrap `iterable`, reporting progress as it is consumed"""
        ...
# progress bars selectable by name; `"none"` and `None` both mean "no progress bar"
ProgressBarOption = Literal["tqdm", "spinner", "none", None]

# default progress bar: `"tqdm"` when tqdm can be imported, otherwise `"spinner"`
DEFAULT_PBAR_FN: ProgressBarOption

try:
    import tqdm  # type: ignore[import-untyped]
except ImportError:
    # tqdm is unavailable -- fall back to the spinner
    DEFAULT_PBAR_FN = "spinner"
else:
    DEFAULT_PBAR_FN = "tqdm"
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` into a list while a `SpinnerContext` spinner is shown

    used as the progress indicator when tqdm is not used. only kwargs that
    `SpinnerContext.__init__` accepts are forwarded. if no `message` ends up being
    forwarded, `desc` is used as the message, or failing that a
    `"Processing {total} items"` message is built from `total`.
    """
    accepted: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)

    spinner_kwargs: dict = dict()
    for key, value in kwargs.items():
        if key in accepted:
            spinner_kwargs[key] = value

    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            # tqdm-style `desc` doubles as the spinner message
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        result = list(x)

    return result
def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """translate generic progress-bar kwargs into kwargs accepted by `tqdm.tqdm`

    we can't wrap tqdm the way the spinner is wrapped (the bar seems to
    disappear), so the kwargs are filtered and remapped here instead:

    - keys not accepted by `tqdm.tqdm.__init__` are dropped
    - if no `desc` was given, `message` becomes the `desc`, or failing that a
      `"Processing {total} items"` description is built from `total`
    """
    accepted: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)

    tqdm_kwargs: dict = dict()
    for key, value in kwargs.items():
        if key in accepted:
            tqdm_kwargs[key] = value

    if "desc" in kwargs:
        return tqdm_kwargs

    if "message" in kwargs:
        tqdm_kwargs["desc"] = kwargs["message"]
    elif "total" in kwargs:
        tqdm_kwargs["desc"] = f"Processing {kwargs.get('total')} items"
    return tqdm_kwargs
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """pass-through "progress bar": ignores every kwarg and returns `x` itself"""
    del kwargs  # accepted only so the signature matches real progress bars
    return x
def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option. if a callable, it is returned as-is.
        if a string, it selects the progress bar: `"tqdm"`, `"spinner"`, or
        `"none"` (equivalent to `None`, meaning no progress bar)
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function. these take precedence over
        `extra_kwargs`. if it contains a truthy `disable`, no progress bar is used
        (defaults to `None`)
     - `**extra_kwargs`
        default kwargs (e.g. `total`) merged underneath `pbar_kwargs`

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
        a tuple of the progress bar function and the kwargs to call it with.
        for `"tqdm"` the kwargs are filtered/remapped by `map_kwargs_for_tqdm`;
        for `"spinner"` they are pre-bound with `functools.partial`, so the
        returned dict is empty

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options
     - `ImportError` : if `pbar="tqdm"` but the `tqdm` package is not installed
    """
    pbar_fn: ProgressBarFunction

    # build a new dict so the caller's `pbar_kwargs` is never mutated;
    # explicitly passed `pbar_kwargs` override the `extra_kwargs` defaults
    merged_kwargs: Dict[str, Any] = {**extra_kwargs, **(pbar_kwargs or dict())}

    # no progress bar if `pbar` is None or "none", or if `disable` is truthy
    if (pbar is None) or (pbar == "none") or merged_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # any other string names a built-in progress bar
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # the module-level `tqdm` name only exists if the import at module
            # load succeeded; import here so a missing package gives a clear
            # ImportError instead of a NameError
            try:
                import tqdm as _tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or use `pbar='spinner'` instead"
                ) from e
            pbar_fn = _tqdm.tqdm
            merged_kwargs = map_kwargs_for_tqdm(merged_kwargs)
        elif pbar == "spinner":
            # the spinner wrapper takes its options up front, so bind them now
            pbar_fn = functools.partial(spinner_fn_wrap, **merged_kwargs)
            merged_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )

    # a user-supplied callable: use it as-is and pass it the merged kwargs
    # (including defaults such as `total`)
    else:
        pbar_fn = pbar

    return pbar_fn, merged_kwargs
# TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is capped at the number of inputs; if that leaves a
    single process, the work runs in serial without creating a pool.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap`/`Pool.imap_unordered`
     - `iterable : Iterable[InputType]`
        inputs passed to `func`. iterables without a `len()` (e.g. generators)
        are first collected into a list
     - `parallel : bool | int`
        whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `chunksize : Optional[int]`
        chunksize for the pool's `imap`. if `None`, uses `max(1, n_inputs // num_processes)`
        (defaults to `None`)
     - `keep_ordered : bool`
        if `True`, outputs are in input order (`imap`); otherwise `imap_unordered` is used.
        has no effect when running in serial
        (defaults to `True`)
     - `use_multiprocess : bool`
        use the `multiprocess` package instead of `multiprocessing` (useful for pickling lambdas)
        (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option, see `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
        a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """
    # validate `parallel` first, so bad arguments are reported regardless of input size
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # check what the caller *asked for*: a single input (or a single cpu) may
    # legitimately fall back to serial below, which must not trigger this error
    if use_multiprocess and not parallel:
        raise ValueError("`use_multiprocess=True` requires `parallel=True`")

    # number of inputs; iterables without `len()` are materialized first
    n_inputs: int
    try:
        n_inputs = len(iterable)  # type: ignore[arg-type]
    except TypeError:
        iterable = list(iterable)
        n_inputs = len(iterable)

    if n_inputs == 0:
        # return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # never use more processes than inputs; a single process means serial
    num_processes = min(num_processes, n_inputs)
    use_pool: bool = num_processes > 1

    # `mp.Pool` may come from `multiprocess` instead of `multiprocessing`
    mp = multiprocessing
    if use_multiprocess:
        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e
        mp = multiprocess

    if not use_pool:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # the `with` block terminates the pool if `func` or the progress bar raises,
    # so worker processes are not leaked; on success we close + join normally
    with mp.Pool(num_processes) as pool:
        # `imap` keeps input order, `imap_unordered` yields results as they finish
        map_method = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                map_method(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
        pool.close()
        pool.join()

    return output