Coverage for muutils/parallel.py: 93%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1import multiprocessing 

2import functools 

3from typing import ( 

4 Any, 

5 Callable, 

6 Iterable, 

7 Literal, 

8 Optional, 

9 Tuple, 

10 TypeVar, 

11 Dict, 

12 List, 

13 Union, 

14 Protocol, 

15) 

16 

17# for no tqdm fallback 

18from muutils.spinner import SpinnerContext 

19from muutils.validate_type import get_fn_allowed_kwargs 

20 

21 

# type variables used in the signature of `run_maybe_parallel`:
# `InputType` is the element type of the iterable being mapped over,
# `OutputType` is the type of the value the mapped function returns
InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")

25 

26 

class ProgressBarFunction(Protocol):
    """structural type of a progress bar callable

    anything that can be called with an iterable plus arbitrary keyword
    arguments and hands back an iterable satisfies this protocol
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        ...

31 

32 

# the values accepted for selecting a built-in progress bar by name
ProgressBarOption = Literal["tqdm", "spinner", "none", None]

# name of the progress bar that is used when the caller does not choose one
DEFAULT_PBAR_FN: ProgressBarOption

try:
    import tqdm  # type: ignore[import-untyped]
except ImportError:
    # tqdm is not installed, fall back to the spinner
    DEFAULT_PBAR_FN = "spinner"
else:
    # tqdm is available, so prefer it
    DEFAULT_PBAR_FN = "tqdm"

49 

50 

def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` into a list while a `SpinnerContext` is active

    only the keyword arguments accepted by `SpinnerContext.__init__` are
    forwarded to the spinner. when no `message` is forwarded, one is taken
    from `desc` if it was given, and otherwise generated from `total` if
    that was given.
    """
    accepted: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)
    spinner_kwargs: dict = {
        key: value for key, value in kwargs.items() if key in accepted
    }

    # fill in a message from the tqdm-style kwargs, `desc` wins over `total`
    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        # the whole iterable is evaluated while the spinner is shown
        items = list(x)

    return items

69 

70 

def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """translate generic progress bar kwargs into kwargs for `tqdm.tqdm`

    keeps only the entries accepted by `tqdm.tqdm.__init__`. if the caller
    gave no `desc`, one is taken from `message` if present, and otherwise
    generated from `total` if present.

    (the original author's note: tqdm is not wrapped in a function the way
    the spinner is, because the progress bar seemed to disappear when wrapped)
    """
    accepted: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)

    tqdm_kwargs: dict = {}
    for name, value in kwargs.items():
        if name in accepted:
            tqdm_kwargs[name] = value

    # an explicit `desc` needs no fallback
    if "desc" in kwargs:
        return tqdm_kwargs

    if "message" in kwargs:
        tqdm_kwargs["desc"] = kwargs["message"]
    elif "total" in kwargs:
        tqdm_kwargs["desc"] = f"Processing {kwargs.get('total')} items"

    return tqdm_kwargs

83 

84 

def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """stand-in progress bar that displays nothing

    hands back `x` itself, untouched. every keyword argument is accepted
    and ignored, so it can be called like any other progress bar function.
    """
    return x

88 

89 

def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option. if a callable, we return it as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `**extra_kwargs`
       additional kwargs for the progress bar function. on a key collision, the entry in `pbar_kwargs` wins

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
       a tuple of the progress bar function and its kwargs

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options and not a callable
     - `ImportError` : if `pbar="tqdm"` and `tqdm` is not available
    """
    pbar_fn: ProgressBarFunction

    if pbar_kwargs is None:
        pbar_kwargs = dict()

    # explicit `pbar_kwargs` take precedence over `extra_kwargs`
    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # the module-level import of tqdm is allowed to fail, so import it again
            # here to report a missing package clearly instead of hitting a `NameError`
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or use `pbar='spinner'` or `pbar='none'`"
                ) from e

            pbar_fn = tqdm.tqdm
            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
        elif pbar == "spinner":
            # the spinner kwargs are bound into the function, so none are returned separately
            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
            pbar_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )

    # a user-supplied progress bar function is used as-is, with the merged kwargs
    elif callable(pbar):
        pbar_fn = pbar

    # anything else can never work as a progress bar, so fail early
    else:
        raise ValueError(
            f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
        )

    return pbar_fn, pbar_kwargs

139 

140 

141# TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes 

def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`.
    if this works out to a single process, the function runs in serial even if `parallel` was requested

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. if it has no `__len__` (e.g. a generator), it is first consumed into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `chunksize : Optional[int]`
       chunksize passed to `Pool.imap` / `Pool.imap_unordered`. if `None`, uses `max(1, len(iterable) // num_processes)`
       (defaults to `None`)
     - `keep_ordered : bool`
       if `True`, use `Pool.imap`, otherwise use `Pool.imap_unordered`. only matters when running in parallel
       (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing` for the pool
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option, passed to `set_up_progress_bar_fn`
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """

    # the number of inputs is needed for the progress bar total, the cap on the
    # number of processes, and the default chunksize -- so an iterable without a
    # length (e.g. a generator) is materialized first
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    # (`bool` has to be checked first, since `bool` is a subclass of `int`)
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    run_parallel: bool = bool(parallel) and num_processes > 1

    mp = multiprocessing
    if use_multiprocess:
        # check what the caller asked for, not `run_parallel` -- otherwise a single
        # input or a single cpu would raise even though `parallel=True` was passed
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # serial case: plain `map` with a progress bar, no pool to manage
    if not run_parallel:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered

        # figure out a smart chunksize if one is not given
        chunksize_int: int = (
            max(1, n_inputs // num_processes) if chunksize is None else chunksize
        )

        # run the map function with a progress bar
        output: List[OutputType] = list(
            pbar_fn(
                pool_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # something failed (in `func`, the progress bar, or an interrupt):
        # stop the workers right away instead of leaking the pool
        pool.terminate()
        raise
    else:
        # all results were collected, shut the workers down gracefully
        pool.close()
    finally:
        pool.join()

    # return the output as a list
    return output