Coverage for muutils/parallel.py: 93%
92 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1import multiprocessing
2import functools
3from typing import (
4 Any,
5 Callable,
6 Iterable,
7 Literal,
8 Optional,
9 Tuple,
10 TypeVar,
11 Dict,
12 List,
13 Union,
14 Protocol,
15)
17# for no tqdm fallback
18from muutils.spinner import SpinnerContext
19from muutils.validate_type import get_fn_allowed_kwargs
# typevars for our iterable and map: `InputType` is the element type of the
# iterable handed to `run_maybe_parallel`, `OutputType` is what `func` returns
InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
class ProgressBarFunction(Protocol):
    """structural type for anything usable as a progress bar

    a conforming callable accepts an iterable plus arbitrary keyword
    arguments, and returns an iterable
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        """wrap `iterable` (configured by `kwargs`) and return an iterable"""
        ...
# the string options accepted for `pbar`; `None` and "none" both mean "no progress bar"
ProgressBarOption = Literal["tqdm", "spinner", "none", None]

# default progress bar option: "tqdm" when the `tqdm` package can be imported,
# otherwise fall back to the spinner from `muutils.spinner`
DEFAULT_PBAR_FN: ProgressBarOption
try:
    import tqdm  # type: ignore[import-untyped]
except ImportError:
    DEFAULT_PBAR_FN = "spinner"
else:
    DEFAULT_PBAR_FN = "tqdm"
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` into a list while a `SpinnerContext` spinner is shown

    kwargs are filtered down to the names `get_fn_allowed_kwargs` reports for
    `SpinnerContext.__init__`. when no `message` survives that filter, one is
    taken from `desc` if given, otherwise built from `total`.

    # Parameters:
     - `x : Iterable`
        the iterable to consume
     - `**kwargs`
        progress bar kwargs; names not reported as allowed are dropped

    # Returns:
     - `List`
        every item of `x`, in iteration order
    """
    accepted: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)
    spinner_kwargs: dict = dict(
        (key, value) for key, value in kwargs.items() if key in accepted
    )

    # derive a spinner message: `desc` takes priority over a `total`-based default
    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        items = list(x)

    return items
def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """translate generic progress bar kwargs into kwargs for `tqdm.tqdm`

    kwargs are filtered down to the names `get_fn_allowed_kwargs` reports for
    `tqdm.tqdm.__init__`. if no `desc` was given, it is filled in from
    `message` when present, otherwise built from `total`.

    NOTE: the original author noted tqdm could not simply be wrapped (as the
    spinner is) because the progress bar seemed to disappear -- cause unconfirmed.

    # Parameters:
     - `kwargs : dict`
        the progress bar kwargs to translate (not modified)

    # Returns:
     - `dict`
        a new dict of kwargs to pass to `tqdm.tqdm`
    """
    accepted: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    tqdm_kwargs: dict = dict(
        (key, value) for key, value in kwargs.items() if key in accepted
    )

    # an explicit `desc` always wins; nothing to derive
    if "desc" in kwargs:
        return tqdm_kwargs

    if "message" in kwargs:
        tqdm_kwargs["desc"] = kwargs["message"]
    elif "total" in kwargs:
        tqdm_kwargs["desc"] = f"Processing {kwargs.get('total')} items"

    return tqdm_kwargs
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """stand-in used when no progress bar is wanted: returns `x` itself, unwrapped

    `kwargs` are accepted only so the signature matches the other progress
    bar functions; they are ignored.
    """
    del kwargs  # intentionally unused
    return x
def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option. if a callable, it is returned as-is.
        if `None` or `"none"`, no progress bar is used. if `"tqdm"` or
        `"spinner"`, the corresponding progress bar is set up
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function; on key collisions these
        take precedence over `extra_kwargs`
        (defaults to `None`)
     - `**extra_kwargs`
        additional kwargs (e.g. `total`) merged underneath `pbar_kwargs`

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
        a tuple of the progress bar function and the kwargs to call it with.
        for `"spinner"` the kwargs are already bound into the returned
        function, so the returned dict is empty

    # Raises:
     - `ValueError` : if `pbar` is a string other than the valid options, or is
        neither `None`, a string, nor a callable
     - `ImportError` : if `pbar="tqdm"` but the `tqdm` package is not installed
    """
    pbar_fn: ProgressBarFunction

    # merge kwargs: explicit `pbar_kwargs` override `extra_kwargs`
    pbar_kwargs = {**extra_kwargs, **(pbar_kwargs or {})}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # the module-level `tqdm` name only exists if its import succeeded;
            # re-importing here turns a confusing `NameError` into a clear `ImportError`
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or use `pbar='spinner'`"
                ) from e
            pbar_fn = tqdm.tqdm
            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
        elif pbar == "spinner":
            # kwargs are bound now, so the caller passes only the iterable
            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
            pbar_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )

    # a user-supplied progress bar function is used as-is
    elif callable(pbar):
        pbar_fn = pbar

    # anything else would only fail later, when it is called -- reject it now
    else:
        raise ValueError(
            f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
        )

    return pbar_fn, pbar_kwargs
# TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is always capped at `len(iterable)`, so with
    `parallel=True` the maximum is `min(len(iterable), multiprocessing.cpu_count())`.
    if this works out to a single process, the builtin `map` is used and no
    pool is created.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap` / `Pool.imap_unordered`
     - `iterable : Iterable[InputType]`
        iterable passed to either `map` or the pool. iterables without a
        `len()` (e.g. generators) are first materialized into a list
     - `parallel : bool | int`
        whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `chunksize : Optional[int]`
        chunksize for the pool's map; if `None`, uses `max(1, n_inputs // num_processes)`
        (defaults to `None`)
     - `keep_ordered : bool`
        if `True`, outputs are in input order (`imap`); otherwise `imap_unordered`
        is used and output order is arbitrary
        (defaults to `True`)
     - `use_multiprocess : bool`
        use the `multiprocess` package instead of `multiprocessing`, mostly
        useful when `func` must be pickled but is e.g. a lambda
        (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar to use, resolved by `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
        a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
     - `ValueError` / `ImportError` : from `set_up_progress_bar_fn`, if `pbar` is invalid or unavailable
    """

    # number of inputs in iterable; `len()` sizes both the pool and the progress
    # bar, so materialize iterables that have no length (e.g. generators)
    n_inputs: int
    try:
        n_inputs = len(iterable)  # type: ignore[arg-type]
    except TypeError:
        iterable = list(iterable)
        n_inputs = len(iterable)

    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes requested
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # pick the multiprocessing implementation. this is validated against the
    # *requested* `parallel`, before it may be downgraded to serial below --
    # otherwise `parallel=True` with a single input (or CPU) would wrongly raise
    mp = multiprocessing
    if use_multiprocess:
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # make sure we don't have more processes than inputs
    num_processes = min(num_processes, n_inputs)

    # don't bother with a pool if there's only one process: run serially
    if num_processes == 1:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        do_map = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                do_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # a worker (or the progress bar) failed, or we were interrupted:
        # stop outstanding work instead of leaking worker processes
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()

    return output