Coverage for muutils/spinner.py: 93%

147 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""decorator `spinner_decorator` and context manager `SpinnerContext` to display a spinner 

2 

3using the base `Spinner` class while some code is running. 

4""" 

5 

6import os 

7import time 

8from dataclasses import dataclass, field 

9import threading 

10import sys 

11from functools import wraps 

12from typing import ( 

13 List, 

14 Dict, 

15 Callable, 

16 Any, 

17 Literal, 

18 Optional, 

19 TextIO, 

20 TypeVar, 

21 Sequence, 

22 Union, 

23 ContextManager, 

24) 

25import warnings 

26 

# Type variable bound to "any callable": `spinner_decorator` is annotated to
# hand back a callable of the very same type it was given.
DecoratedFunction = TypeVar("DecoratedFunction", bound=Callable[..., Any])
"Define a generic type for the decorated function"

29 

30 

@dataclass
class SpinnerConfig:
    """frames and end markers for a spinner

    - `working`: non-empty list of frames cycled while the spinner runs
    - `success`: shown in place of the frame when the spinner stops normally
    - `fail`: shown in place of the frame when the spinner stops after a failure

    raises `ValueError` on construction if the fields fail `is_valid()`.
    """

    working: List[str] = field(default_factory=lambda: ["|", "/", "-", "\\"])
    success: str = "✔️"
    fail: str = "❌"

    def is_ascii(self) -> bool:
        "whether all characters (working frames, success, and fail) are ascii"
        return all(s.isascii() for s in self.working + [self.success, self.fail])

    def eq_lens(self) -> bool:
        """whether every working frame, and the success and fail markers, have
        the same length (in code points) as the first working frame"""
        expected_len: int = len(self.working[0])
        return all(
            len(char) == expected_len
            for char in self.working + [self.success, self.fail]
        )

    def is_valid(self) -> bool:
        """whether the spinner config is valid: `working` is a non-empty list of
        `str`, and `success` / `fail` are `str`"""
        # type checks come first and the chain short-circuits, so malformed
        # fields (e.g. `working=None` or `working=5`) yield False instead of
        # raising from `len()` or iteration
        return (
            isinstance(self.working, list)
            and len(self.working) > 0
            and all(isinstance(char, str) for char in self.working)
            and isinstance(self.success, str)
            and isinstance(self.fail, str)
        )

    def __post_init__(self):
        if not self.is_valid():
            raise ValueError(f"Invalid SpinnerConfig: {self}")

    @classmethod
    def from_any(cls, arg: "SpinnerConfigArg") -> "SpinnerConfig":
        """build a `SpinnerConfig` from any accepted form

        - `str`: key looked up in the module-level `SPINNERS` (the shared
          instance is returned, not a copy)
        - `list`: used as `working`, with default `success` / `fail`
        - `dict`: passed as kwargs to `SpinnerConfig`
        - `SpinnerConfig`: returned unchanged

        # Raises:
        - `KeyError` if a `str` key is not in `SPINNERS`
        - `TypeError` for any other argument type
        - `ValueError` if the resulting config is invalid
        """
        if isinstance(arg, str):
            try:
                return SPINNERS[arg]
            except KeyError:
                raise KeyError(
                    f"unknown spinner config key {arg!r}, available keys: {sorted(SPINNERS)}"
                ) from None
        elif isinstance(arg, list):
            return SpinnerConfig(working=arg)
        elif isinstance(arg, dict):
            return SpinnerConfig(**arg)
        elif isinstance(arg, SpinnerConfig):
            return arg
        else:
            raise TypeError(
                f"to create a SpinnerConfig, you must pass a string (key), list (working seq), dict (kwargs to SpinnerConfig), or SpinnerConfig, but got {type(arg) = }, {arg = }"
            )

81 

82 

SpinnerConfigArg = Union[str, List[str], SpinnerConfig, dict]

# spinner configurations, looked up by key in `SpinnerConfig.from_any`
# NOTE(review): frames in `satellite_signal`, `rocket_orbit`, and `ogham`
# have uneven widths; confirm the intended padding.
SPINNERS: Dict[str, SpinnerConfig] = dict(
    default=SpinnerConfig(working=["|", "/", "-", "\\"], success="#", fail="X"),
    # working frames are padded to the width of the success/fail markers
    dots=SpinnerConfig(working=[".  ", ".. ", "..."], success="***", fail="xxx"),
    bars=SpinnerConfig(working=["|  ", "|| ", "|||"], success="|||", fail="///"),
    arrows=SpinnerConfig(working=["<", "^", ">", "v"], success="►", fail="✖"),
    arrows_2=SpinnerConfig(
        working=["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"], success="→", fail="↯"
    ),
    bouncing_bar=SpinnerConfig(
        working=["[    ]", "[=   ]", "[==  ]", "[=== ]", "[ ===]", "[  ==]", "[   =]"],
        success="[====]",
        fail="[XXXX]",
    ),
    bar=SpinnerConfig(
        working=["[  ]", "[- ]", "[--]", "[ -]"],
        success="[==]",
        fail="[xx]",
    ),
    bouncing_ball=SpinnerConfig(
        # the ball travels right across 6 cells, then back to the left edge
        working=[
            "( ●    )",
            "(  ●   )",
            "(   ●  )",
            "(    ● )",
            "(     ●)",
            "(    ● )",
            "(   ●  )",
            "(  ●   )",
            "( ●    )",
            "(●     )",
        ],
        success="(●●●●●●)",
        fail="( ✖ )",
    ),
    ooo=SpinnerConfig(working=[".", "o", "O", "o"], success="O", fail="x"),
    braille=SpinnerConfig(
        working=["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
        success="⣿",
        fail="X",
    ),
    clock=SpinnerConfig(
        working=[
            "🕛",
            "🕐",
            "🕑",
            "🕒",
            "🕓",
            "🕔",
            "🕕",
            "🕖",
            "🕗",
            "🕘",
            "🕙",
            "🕚",
        ],
        success="✔️",
        fail="❌",
    ),
    hourglass=SpinnerConfig(working=["⏳", "⌛"], success="✔️", fail="❌"),
    square_corners=SpinnerConfig(working=["◰", "◳", "◲", "◱"], success="◼", fail="✖"),
    triangle=SpinnerConfig(working=["◢", "◣", "◤", "◥"], success="◆", fail="✖"),
    square_dot=SpinnerConfig(
        working=["⣷", "⣯", "⣟", "⡿", "⢿", "⣻", "⣽", "⣾"], success="⣿", fail="❌"
    ),
    box_bounce=SpinnerConfig(working=["▌", "▀", "▐", "▄"], success="■", fail="✖"),
    hamburger=SpinnerConfig(working=["☱", "☲", "☴"], success="☰", fail="✖"),
    earth=SpinnerConfig(working=["🌍", "🌎", "🌏"], success="✔️", fail="❌"),
    growing_dots=SpinnerConfig(
        working=["⣀", "⣄", "⣤", "⣦", "⣶", "⣷", "⣿"], success="⣿", fail="✖"
    ),
    dice=SpinnerConfig(working=["⚀", "⚁", "⚂", "⚃", "⚄", "⚅"], success="🎲", fail="✖"),
    wifi=SpinnerConfig(
        working=["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"], success="✔️", fail="❌"
    ),
    bounce=SpinnerConfig(working=["⠁", "⠂", "⠄", "⠂"], success="⠿", fail="⢿"),
    arc=SpinnerConfig(working=["◜", "◠", "◝", "◞", "◡", "◟"], success="○", fail="✖"),
    toggle=SpinnerConfig(working=["⊶", "⊷"], success="⊷", fail="⊗"),
    toggle2=SpinnerConfig(working=["▫", "▪"], success="▪", fail="✖"),
    toggle3=SpinnerConfig(working=["□", "■"], success="■", fail="✖"),
    toggle4=SpinnerConfig(working=["■", "□", "▪", "▫"], success="■", fail="✖"),
    toggle5=SpinnerConfig(working=["▮", "▯"], success="▮", fail="✖"),
    toggle7=SpinnerConfig(working=["⦾", "⦿"], success="⦿", fail="✖"),
    toggle8=SpinnerConfig(working=["◍", "◌"], success="◍", fail="✖"),
    toggle9=SpinnerConfig(working=["◉", "◎"], success="◉", fail="✖"),
    arrow2=SpinnerConfig(
        working=["⬆️ ", "↗️ ", "➡️ ", "↘️ ", "⬇️ ", "↙️ ", "⬅️ ", "↖️ "], success="➡️", fail="❌"
    ),
    point=SpinnerConfig(
        working=["∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙"], success="●●●", fail="xxx"
    ),
    layer=SpinnerConfig(working=["-", "=", "≡"], success="≡", fail="✖"),
    speaker=SpinnerConfig(
        working=["🔈 ", "🔉 ", "🔊 ", "🔉 "], success="🔊", fail="🔇"
    ),
    # pulse frames: small diamond, large diamond, circle, circle, large diamond
    orangePulse=SpinnerConfig(
        working=["🔸 ", "🔶 ", "🟠 ", "🟠 ", "🔶 "], success="🟠", fail="❌"
    ),
    bluePulse=SpinnerConfig(
        working=["🔹 ", "🔷 ", "🔵 ", "🔵 ", "🔷 "], success="🔵", fail="❌"
    ),
    satellite_signal=SpinnerConfig(
        working=["📡 ", "📡· ", "📡·· ", "📡···", "📡 ··", "📡 ·"],
        success="📡 ✔️ ",
        fail="📡 ❌ ",
    ),
    rocket_orbit=SpinnerConfig(
        working=["🌍🚀 ", "🌏 🚀 ", "🌎 🚀"], success="🌍 ✨", fail="🌍 💥"
    ),
    ogham=SpinnerConfig(working=["ᚁ ", "ᚂ ", "ᚃ ", "ᚄ", "ᚅ"], success="᚛᚜", fail="✖"),
    eth=SpinnerConfig(
        working=["᛫", "፡", "፥", "፤", "፧", "።", "፨"], success="፠", fail="✖"
    ),
)

199 

200 

class Spinner:
    """displays a spinner, and optionally elapsed time and a mutable value while a function is running.

    the animation runs in a background thread started by `start()`; `stop()`
    halts that thread first and only then prints the final success/fail line,
    so the final line is never overwritten by a late working frame.

    # Parameters:

    - `config : SpinnerConfigArg`
        which frames and success/fail markers to use: a key into `SPINNERS`,
        a list of working frames, a dict of kwargs for `SpinnerConfig`,
        or a `SpinnerConfig`
        (defaults to `"default"`)
    - `update_interval : float`
        how often to update the spinner display in seconds
        (defaults to `0.1`)
    - `initial_value : str`
        initial value to display with the spinner
        (defaults to `""`)
    - `message : str`
        message to display with the spinner
        (defaults to `""`)
    - `format_string : str`
        string to format the spinner with. must have `"\\r"` prepended to clear the line.
        allowed keys are `spinner`, `elapsed_time`, `message`, and `value`
        (defaults to `"\\r{spinner} ({elapsed_time:.2f}s) {message}{value}"`)
    - `output_stream : TextIO`
        stream to write the spinner to
        (defaults to `sys.stdout` as it was when this module was imported)
    - `format_string_when_updated : Union[bool,str]`
        whether to use a different format string when the value is updated.
        if `True`, use the default format string with a newline appended. if a string, use that string.
        this is useful if you want update_value to print to console and be preserved.
        (defaults to `False`)

    # Deprecated Parameters:

    - `spinner_chars : Union[str, Sequence[str]]`
    - `spinner_complete : str`
        both are ignored; passing either emits a `DeprecationWarning`. use `config` instead.

    # Raises:

    - `ValueError` for positional arguments, unrecognized keyword arguments,
      or a `format_string` that cannot be formatted with the keys above
    - `TypeError` if `format_string_when_updated` is neither a `bool` nor a `str`

    # Methods:
    - `start() -> None`
        record the start time and launch the display thread
    - `stop(failed: bool = False) -> None`
        stop the display thread, then print the final success/fail line and a newline
    - `update_value(value: Any) -> None`
        update the current value displayed by the spinner

    # Usage:

    ## As a context manager:
    ```python
    with SpinnerContext() as sp:
        for i in range(1):
            time.sleep(0.1)
            sp.update_value(f"Step {i+1}")
    ```

    ## As a decorator:
    ```python
    @spinner_decorator(mutable_kwarg_key="update_status")
    def long_running_function(update_status):
        for i in range(1):
            time.sleep(0.1)
            update_status(f"Step {i+1}")
        return "Function completed"
    ```
    """

    def __init__(
        self,
        # no positional args
        *args,
        # quoted so the alias is resolved lazily rather than at definition time
        config: "SpinnerConfigArg" = "default",
        update_interval: float = 0.1,
        initial_value: str = "",
        message: str = "",
        format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
        output_stream: TextIO = sys.stdout,
        format_string_when_updated: Union[str, bool] = False,
        # deprecated
        spinner_chars: Optional[Union[str, Sequence[str]]] = None,
        spinner_complete: Optional[str] = None,
        # no other kwargs accepted
        **kwargs: Any,
    ):
        if args:
            raise ValueError(f"Spinner does not accept positional arguments: {args}")
        if kwargs:
            raise ValueError(
                f"Spinner did not recognize these keyword arguments: {kwargs}"
            )

        # old spinner display
        if (spinner_chars is not None) or (spinner_complete is not None):
            warnings.warn(
                "spinner_chars and spinner_complete are deprecated and will have no effect. Use `config` instead.",
                DeprecationWarning,
            )

        # config
        self.config: SpinnerConfig = SpinnerConfig.from_any(config)

        # special format string for when the value is updated
        self.format_string_when_updated: Optional[str] = None
        "format string to use when the value is updated"
        if format_string_when_updated is not False:
            if format_string_when_updated is True:
                # modify the default format string
                self.format_string_when_updated = format_string + "\n"
            elif isinstance(format_string_when_updated, str):
                # use the provided format string
                self.format_string_when_updated = format_string_when_updated
            else:
                raise TypeError(
                    "format_string_when_updated must be a string or True, got"
                    + f" {type(format_string_when_updated) = }, {format_string_when_updated = }"
                )

        # copy other kwargs
        self.update_interval: float = update_interval
        self.message: str = message
        self.current_value: Any = initial_value
        self.format_string: str = format_string
        self.output_stream: TextIO = output_stream

        # fail fast: try the format string once with representative values
        try:
            self.format_string.format(
                spinner=self.config.working[0],
                elapsed_time=0.0,
                message=self.message,
                value=self.current_value,
            )
        except Exception as e:
            raise ValueError(
                f"Invalid format string: {format_string}. Must take keys "
                + "'spinner: str', 'elapsed_time: float', 'message: str', and 'value: Any'."
            ) from e

        # init
        self.start_time: float = 0
        "for measuring elapsed time"
        self.stop_spinner: threading.Event = threading.Event()
        "to stop the spinner"
        self.spinner_thread: Optional[threading.Thread] = None
        "the thread running the spinner"
        self.value_changed: bool = False
        "whether the value has been updated since the last display"
        self.term_width: int
        "width of the terminal, for padding with spaces"
        try:
            self.term_width = os.get_terminal_size().columns
        except OSError:
            # no attached terminal (e.g. output is piped)
            self.term_width = 80

        # state of the spinner
        self.state: Literal["initialized", "running", "success", "fail"] = "initialized"

    def spin(self) -> None:
        "Function to run in a separate thread, displaying the spinner and optional information"
        i: int = 0
        while not self.stop_spinner.is_set():
            # get current spinner str
            spinner: str = self.config.working[i % len(self.config.working)]

            # args for display string
            display_parts: Dict[str, Any] = dict(
                spinner=spinner,  # str
                elapsed_time=time.time() - self.start_time,  # float
                message=self.message,  # str
                value=self.current_value,  # Any, but will be formatted as str
            )

            # use the special one if needed
            format_str: str = self.format_string
            if self.value_changed and (self.format_string_when_updated is not None):
                self.value_changed = False
                format_str = self.format_string_when_updated

            # write and flush the display string, padded to overwrite any
            # longer previous frame
            output: str = format_str.format(**display_parts).ljust(self.term_width)
            self.output_stream.write(output)
            self.output_stream.flush()

            # wait for the next update, but wake immediately when `stop()`
            # sets the event instead of sleeping out the whole interval
            self.stop_spinner.wait(self.update_interval)
            i += 1

    def update_value(self, value: Any) -> None:
        "Update the current value displayed by the spinner"
        self.current_value = value
        self.value_changed = True

    def start(self) -> None:
        "Start the spinner"
        self.start_time = time.time()
        # daemon thread: if `stop()` is never reached, the interpreter can
        # still exit instead of waiting forever on this loop
        self.spinner_thread = threading.Thread(target=self.spin, daemon=True)
        self.spinner_thread.start()
        self.state = "running"

    def stop(self, failed: bool = False) -> None:
        """Stop the spinner thread, then write the final line and a newline

        the final line uses `format_string` with the config's `fail` marker if
        `failed`, otherwise its `success` marker.
        """
        # measure before waiting on the worker thread
        elapsed_time: float = time.time() - self.start_time

        # halt the worker *before* writing the final line; otherwise it can
        # emit one more working frame that overwrites the success/fail marker
        self.stop_spinner.set()
        if self.spinner_thread is not None:
            self.spinner_thread.join()

        final_line: str = self.format_string.format(
            spinner=self.config.fail if failed else self.config.success,
            elapsed_time=elapsed_time,  # float
            message=self.message,  # str
            value=self.current_value,  # Any, but will be formatted as str
        )
        self.output_stream.write(final_line.ljust(self.term_width))
        self.output_stream.write("\n")
        self.output_stream.flush()

        self.state = "fail" if failed else "success"

411 

412 

class NoOpContextManager(ContextManager):
    """Stand-in context manager with no setup or teardown.

    Any constructor arguments are accepted and discarded. Entering yields the
    manager itself; exiting returns a falsy value, so exceptions raised in the
    `with` body propagate unchanged.
    """

    def __init__(self, *args, **kwargs):
        # arguments exist only for signature compatibility and are dropped
        del args, kwargs

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # explicit None: never suppress an in-flight exception
        return None

424 

425 

class SpinnerContext(Spinner, ContextManager):
    "see `Spinner` for parameters"

    def __enter__(self) -> "SpinnerContext":
        # the animation begins as soon as the `with` body is entered
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # an exception leaving the body marks the run as failed; returning
        # None lets that exception keep propagating
        did_fail: bool = exc_type is not None
        self.stop(failed=did_fail)


SpinnerContext.__doc__ = Spinner.__doc__

438 

439 

# TODO: type hint that the `update_status` kwarg is not needed when calling the function we just decorated
def spinner_decorator(
    *args,
    # passed to `Spinner.__init__`
    config: SpinnerConfigArg = "default",
    update_interval: float = 0.1,
    initial_value: str = "",
    message: str = "",
    # NOTE(review): unlike `Spinner`'s default this has no leading "\r", so
    # frames are not drawn over each other; kept for backward compatibility
    format_string: str = "{spinner} ({elapsed_time:.2f}s) {message}{value}",
    output_stream: TextIO = sys.stdout,
    # new kwarg
    mutable_kwarg_key: Optional[str] = None,
    # deprecated
    spinner_chars: Union[str, Sequence[str], None] = None,
    spinner_complete: Optional[str] = None,
    # also passed to `Spinner.__init__`; same default as `Spinner`
    format_string_when_updated: Union[str, bool] = False,
    **kwargs,
) -> Callable[[DecoratedFunction], DecoratedFunction]:
    """see `Spinner` for parameters. Also takes `mutable_kwarg_key`

    `mutable_kwarg_key` is the key with which `Spinner().update_value`
    will be passed to the decorated function. if `None`, won't pass it.

    `format_string_when_updated` is forwarded to `Spinner` as well.

    the default `format_string` here has no leading `"\\r"`, unlike `Spinner`'s
    default, so successive frames are written one after another rather than
    over each other; pass a format string starting with `"\\r"` to redraw in place.

    can be used bare (`@spinner_decorator`) or with keyword arguments
    (`@spinner_decorator(message="...")`). a fresh `Spinner` is created for each
    call of the decorated function. if the function raises anything (including
    `KeyboardInterrupt`), the spinner is stopped as failed and the exception is
    re-raised.

    # Raises:
    - `ValueError` for positional arguments other than a single callable
      (the bare-decorator form), or for unrecognized keyword arguments
    """

    # a single callable positional arg is the bare `@spinner_decorator` form;
    # anything else positional is rejected up front instead of failing later
    if len(args) > 1 or (args and not callable(args[0])):
        raise ValueError(
            f"spinner_decorator does not accept positional arguments: {args}"
        )
    if kwargs:
        raise ValueError(
            f"spinner_decorator did not recognize these keyword arguments: {kwargs}"
        )

    def decorator(func: DecoratedFunction) -> DecoratedFunction:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            spinner: Spinner = Spinner(
                config=config,
                update_interval=update_interval,
                initial_value=initial_value,
                message=message,
                format_string=format_string,
                output_stream=output_stream,
                format_string_when_updated=format_string_when_updated,
                spinner_chars=spinner_chars,
                spinner_complete=spinner_complete,
            )

            if mutable_kwarg_key:
                kwargs[mutable_kwarg_key] = spinner.update_value

            spinner.start()
            try:
                result: Any = func(*args, **kwargs)
            except BaseException:
                # BaseException, not just Exception: a KeyboardInterrupt or
                # SystemExit must also stop the spinner thread, or it would
                # keep running after the exception propagates
                spinner.stop(failed=True)
                raise
            spinner.stop(failed=False)
            return result

        # TODO: fix this type ignore
        return wrapper  # type: ignore[return-value]

    if not args:
        # called as `@spinner_decorator(stuff)`
        return decorator
    else:
        # called as `@spinner_decorator` without parens
        return decorator(args[0])


# keep the decorator-specific notes (e.g. `mutable_kwarg_key`) and append the
# full `Spinner` parameter documentation, rather than replacing one with the other
spinner_decorator.__doc__ = (
    f"{spinner_decorator.__doc__ or ''}\n\n{Spinner.__doc__ or ''}"
)