Coverage for muutils/parallel.py: 94%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-07 20:16 -0700

1"parallel processing utilities, chiefly `run_maybe_parallel`" 

2 

3from __future__ import annotations 

4 

5import multiprocessing 

6import functools 

7from typing import ( 

8 Any, 

9 Callable, 

10 Iterable, 

11 Literal, 

12 Optional, 

13 Tuple, 

14 TypeVar, 

15 Dict, 

16 List, 

17 Union, 

18 Protocol, 

19) 

20 

21# for no tqdm fallback 

22from muutils.spinner import SpinnerContext 

23from muutils.validate_type import get_fn_allowed_kwargs 

24 

25 

# typevars for our iterable and map:
# `InputType` is the element type fed to the mapped function,
# `OutputType` is the type that function returns
InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")

29 

30 

class ProgressBarFunction(Protocol):
    """a protocol for a progress bar function

    any callable invoked as `fn(iterable, **kwargs)` which returns an iterable
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        ...

35 

36 

# type for the progress bar option
ProgressBarOption = Literal["tqdm", "spinner", "none", None]

# default progress bar option: "tqdm" when it can be imported, "spinner" otherwise
DEFAULT_PBAR_FN: ProgressBarOption

try:
    import tqdm  # type: ignore[import-untyped]
except ImportError:
    # tqdm is not installed, use the spinner as a fallback
    DEFAULT_PBAR_FN = "spinner"
else:
    # use tqdm since it's available
    DEFAULT_PBAR_FN = "tqdm"

53 

54 

def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` into a list while a `SpinnerContext` is active

    only the kwargs reported by `get_fn_allowed_kwargs(SpinnerContext.__init__)`
    are forwarded to the spinner. if no `message` is forwarded, one is taken
    from `desc` when given, otherwise built from `total` when given.

    # Parameters:
    - `x : Iterable`
        iterable to exhaust inside the spinner context
    - `**kwargs`
        progress bar style kwargs (e.g. `desc`, `total`, `message`)

    # Returns:
    - `List`
        all elements of `x`, in iteration order
    """
    allowed: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)

    # keep only what the spinner constructor accepts
    spinner_kwargs: dict = {}
    for key, value in kwargs.items():
        if key in allowed:
            spinner_kwargs[key] = value

    # choose a message: explicit `message` > `desc` > generic text from `total`
    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        collected = list(x)

    return collected

73 

74 

def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """translate generic progress bar kwargs into kwargs for `tqdm.tqdm`

    keeps only the keys reported by `get_fn_allowed_kwargs(tqdm.tqdm.__init__)`.
    when no `desc` was given, one is taken from `message` if present,
    otherwise built from `total` if present.

    (tqdm is handed kwargs rather than being wrapped -- the original note
    says that wrapping it made the progress bar disappear)

    # Parameters:
    - `kwargs : dict`
        kwargs intended for the progress bar

    # Returns:
    - `dict`
        a new dict of kwargs to pass to `tqdm.tqdm`; `kwargs` is not modified
    """
    allowed: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    tqdm_kwargs: dict = {}
    for key, value in kwargs.items():
        if key in allowed:
            tqdm_kwargs[key] = value

    # an explicit `desc` needs no fallback
    if "desc" in kwargs:
        return tqdm_kwargs

    if "message" in kwargs:
        tqdm_kwargs["desc"] = kwargs["message"]
    elif "total" in kwargs:
        tqdm_kwargs["desc"] = f"Processing {kwargs.get('total')} items"

    return tqdm_kwargs

87 

88 

def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """fallback to no progress bar

    returns `x` itself, untouched. `**kwargs` are accepted only so that the
    signature matches other progress bar functions, and are ignored.
    """
    return x

92 

93 

def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
    - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option. if a function, we return it as-is
        (unless the progress bar is disabled, see below). if a string, we
        figure out which progress bar to use
    - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
    - `**extra_kwargs`
        additional kwargs merged with `pbar_kwargs`. if a key is present in
        both, the value from `pbar_kwargs` wins

    the progress bar is disabled (replaced by `no_progress_fn_wrap`) if `pbar`
    is `None` or `"none"`, or if the merged kwargs have a truthy `disable`

    # Returns:
    - `Tuple[ProgressBarFunction, dict]`
        a tuple of the progress bar function and its kwargs

    # Raises:
    - `ValueError` : if `pbar` is not one of the valid options
    - `ImportError` : if `pbar` is `"tqdm"` but `tqdm` is not installed
    """
    pbar_fn: ProgressBarFunction

    # build a fresh dict so that the caller's `pbar_kwargs` is never mutated.
    # `pbar_kwargs` comes last so it overrides `extra_kwargs`
    merged_kwargs: dict = {**extra_kwargs, **(pbar_kwargs or {})}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in the kwargs
    if (pbar is None) or (pbar == "none") or merged_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # the module-level import of tqdm is allowed to fail, so import
            # again here to get a clear error instead of a `NameError`
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package, which is not installed. install it with `pip install tqdm`, or use `pbar='spinner'` or `pbar='none'`"
                ) from e

            pbar_fn = tqdm.tqdm
            merged_kwargs = map_kwargs_for_tqdm(merged_kwargs)
        elif pbar == "spinner":
            # the kwargs are bound into the partial, so none are returned separately
            pbar_fn = functools.partial(spinner_fn_wrap, **merged_kwargs)
            merged_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # `pbar` is a user-supplied callable: use it directly and hand it the merged kwargs
        pbar_fn = pbar

    return pbar_fn, merged_kwargs

143 

144 

145# TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes 

def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`.
    if this works out to a single process, the work is run in serial regardless of `parallel`

    # Parameters:
    - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap`
    - `iterable : Iterable[InputType]`
        iterable passed to either `map` or `Pool.imap`. if it does not support
        `len()` (e.g. a generator), it is first collected into a list
    - `parallel : bool | int`
        whether to run in parallel, and how many processes to use
    - `pbar_kwargs : Dict[str, Any]`
        kwargs passed to the progress bar function
        (defaults to `None`)
    - `chunksize : Optional[int]`
        chunksize passed to `Pool.imap`/`Pool.imap_unordered`. if `None`,
        `max(1, n_inputs // num_processes)` is used
        (defaults to `None`)
    - `keep_ordered : bool`
        if `True` use `Pool.imap`, otherwise `Pool.imap_unordered`. only relevant when running in parallel
        (defaults to `True`)
    - `use_multiprocess : bool`
        use the `multiprocess` package instead of the standard `multiprocessing`
        (defaults to `False`)
    - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option, resolved by `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
    - `List[OutputType]`
        a list of the output of `func` for each element in `iterable`

    # Raises:
    - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
    - `ValueError` : if `use_multiprocess=True` and `parallel=False`
    - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """

    # number of inputs in iterable. unsized iterables (generators etc) are
    # materialized, since the length is needed to size the pool and the progress bar
    n_inputs: int
    try:
        n_inputs = len(iterable)  # type: ignore[arg-type]
    except TypeError:
        iterable = list(iterable)
        n_inputs = len(iterable)

    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    run_parallel: bool = num_processes > 1

    mp = multiprocessing
    if use_multiprocess:
        # validate against what the caller *asked for*, not against the
        # possibly-downgraded `run_parallel` -- otherwise `parallel=True` with a
        # single input would complain that `parallel=True` is required
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # serial case: plain `map` with a progress bar
    if not run_parallel:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int
    if chunksize is None:
        chunksize_int = max(1, n_inputs // num_processes)
    else:
        chunksize_int = chunksize

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    output: List[OutputType]
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered

        # run the map function with a progress bar
        output = list(
            pbar_fn(
                pool_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # something failed (in `func`, the progress bar, or an interrupt):
        # stop the workers right away instead of leaking the pool
        pool.terminate()
        raise
    else:
        # all results were collected, shut down cleanly
        pool.close()
    finally:
        # `join` is valid after either `close` or `terminate`
        pool.join()

    # return the output as a list
    return output