Coverage for muutils / spinner.py: 93%

149 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:25 -0700

1"""decorator `spinner_decorator` and context manager `SpinnerContext` to display a spinner 

2 

3using the base `Spinner` class while some code is running. 

4""" 

5 

6from __future__ import annotations 

7 

8import os 

9import time 

10from dataclasses import dataclass, field 

11import threading 

12import sys 

13from functools import wraps 

14from types import TracebackType 

15from typing import ( 

16 List, 

17 Dict, 

18 Callable, 

19 Any, 

20 Literal, 

21 Optional, 

22 TextIO, 

23 TypeVar, 

24 Sequence, 

25 Union, 

26 ContextManager, 

27) 

28import warnings 

29 

# type variable standing in for "whatever callable is being decorated", so a
# decorator can declare that it hands back the same type it was given
DecoratedFunction = TypeVar(
    "DecoratedFunction",
    bound=Callable[..., Any],
)
"Define a generic type for the decorated function"

32 

33 

@dataclass
class SpinnerConfig:
    """Visual configuration for a `Spinner`.

    - `working : List[str]`
        frames cycled through (in order) while the spinner is running
    - `success : str`
        shown in place of the spinner frame when it stops successfully
    - `fail : str`
        shown in place of the spinner frame when it stops with a failure

    # Raises:
    - `ValueError` : on construction, if the config is not valid (see `is_valid`)
    """

    working: List[str] = field(default_factory=lambda: ["|", "/", "-", "\\"])
    success: str = "✔️"
    fail: str = "❌"

    def is_ascii(self) -> bool:
        "whether all characters (working frames, success, and fail) are ascii"
        return all(s.isascii() for s in self.working + [self.success, self.fail])

    def eq_lens(self) -> bool:
        """whether every working frame, as well as `success` and `fail`, has the
        same length as the first working frame"""
        expected_len: int = len(self.working[0])
        return all(
            len(char) == expected_len
            for char in self.working + [self.success, self.fail]
        )

    def is_valid(self) -> bool:
        """whether the spinner config is valid: `working` is a non-empty list of
        strings, and `success` and `fail` are strings

        the checks short-circuit in order, so a non-list `working` (e.g. `None`)
        makes this return `False` instead of raising a `TypeError` from `len()`
        """
        return (
            isinstance(self.working, list)
            and len(self.working) > 0
            and all(isinstance(char, str) for char in self.working)
            and isinstance(self.success, str)
            and isinstance(self.fail, str)
        )

    def __post_init__(self):
        if not self.is_valid():
            raise ValueError(f"Invalid SpinnerConfig: {self}")

    @classmethod
    def from_any(cls, arg: "SpinnerConfigArg") -> "SpinnerConfig":
        """build a `SpinnerConfig` from any supported argument type

        - `SpinnerConfig` : returned as-is
        - `str` : looked up as a key in `SPINNERS` (the shared preset is returned)
        - `list` / `tuple` of str : used (copied into a new list) as the working frames
        - `dict` : passed as keyword arguments to `SpinnerConfig`

        # Raises:
        - `KeyError` : if a string key is not in `SPINNERS`
        - `ValueError` : if the resulting config is invalid
        - `TypeError` : for any other argument type
        """
        # check SpinnerConfig first to help type narrowing
        if isinstance(arg, SpinnerConfig):
            return arg
        elif isinstance(arg, str):
            try:
                return SPINNERS[arg]
            except KeyError:
                raise KeyError(
                    f"unknown spinner config key {arg!r}, expected one of {sorted(SPINNERS)}"
                ) from None
        elif isinstance(arg, (list, tuple)):
            # copy so later mutation of the caller's sequence does not leak into the config
            return SpinnerConfig(working=list(arg))
        elif isinstance(arg, dict):
            return SpinnerConfig(**arg)
        else:
            raise TypeError(
                f"to create a SpinnerConfig, you must pass a string (key), list (working seq), dict (kwargs to SpinnerConfig), or SpinnerConfig, but got {type(arg) = }, {arg = }"
            )

85 

86 

# the argument types accepted by `SpinnerConfig.from_any`
SpinnerConfigArg = Union[
    str,  # key into `SPINNERS`
    List[str],  # working frames
    SpinnerConfig,  # used as-is
    Dict[str, Any],  # keyword arguments for `SpinnerConfig`
]

88 

# Registry of named spinner presets. `SpinnerConfig.from_any` looks a string key up
# here (e.g. `Spinner(config="dots")`) and returns the shared instance stored below.
# Each entry gives the `working` frames cycled while running and the `success` /
# `fail` strings shown in place of the frame when the spinner stops.
# NOTE(review): in this rendering, runs of spaces inside some frame strings appear
# collapsed (e.g. `bouncing_ball` shows several identical frames, and `dots` /
# `bouncing_bar` frames differ in width) -- verify against the real source.
89SPINNERS: Dict[str, SpinnerConfig] = dict( 

90 default=SpinnerConfig(working=["|", "/", "-", "\\"], success="#", fail="X"), 

91 dots=SpinnerConfig(working=[". ", ".. ", "..."], success="***", fail="xxx"), 

92 bars=SpinnerConfig(working=["| ", "|| ", "|||"], success="|||", fail="///"), 

93 arrows=SpinnerConfig(working=["<", "^", ">", "v"], success="►", fail="✖"), 

94 arrows_2=SpinnerConfig( 

95 working=["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"], success="→", fail="↯" 

96 ), 

97 bouncing_bar=SpinnerConfig( 

98 working=["[ ]", "[= ]", "[== ]", "[=== ]", "[ ===]", "[ ==]", "[ =]"], 

99 success="[====]", 

100 fail="[XXXX]", 

101 ), 

102 bar=SpinnerConfig( 

103 working=["[ ]", "[- ]", "[--]", "[ -]"], 

104 success="[==]", 

105 fail="[xx]", 

106 ), 

107 bouncing_ball=SpinnerConfig( 

108 working=[ 

109 "( ● )", 

110 "( ● )", 

111 "( ● )", 

112 "( ● )", 

113 "( ●)", 

114 "( ● )", 

115 "( ● )", 

116 "( ● )", 

117 "( ● )", 

118 "(● )", 

119 ], 

120 success="(●●●●●●)", 

121 fail="( ✖ )", 

122 ), 

123 ooo=SpinnerConfig(working=[".", "o", "O", "o"], success="O", fail="x"), 

124 braille=SpinnerConfig( 

125 working=["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"], 

126 success="⣿", 

127 fail="X", 

128 ), 

129 clock=SpinnerConfig( 

130 working=[ 

131 "🕛", 

132 "🕐", 

133 "🕑", 

134 "🕒", 

135 "🕓", 

136 "🕔", 

137 "🕕", 

138 "🕖", 

139 "🕗", 

140 "🕘", 

141 "🕙", 

142 "🕚", 

143 ], 

144 success="✔️", 

145 fail="❌", 

146 ), 

147 hourglass=SpinnerConfig(working=["⏳", "⌛"], success="✔️", fail="❌"), 

148 square_corners=SpinnerConfig(working=["◰", "◳", "◲", "◱"], success="◼", fail="✖"), 

149 triangle=SpinnerConfig(working=["◢", "◣", "◤", "◥"], success="◆", fail="✖"), 

150 square_dot=SpinnerConfig( 

151 working=["⣷", "⣯", "⣟", "⡿", "⢿", "⣻", "⣽", "⣾"], success="⣿", fail="❌" 

152 ), 

153 box_bounce=SpinnerConfig(working=["▌", "▀", "▐", "▄"], success="■", fail="✖"), 

154 hamburger=SpinnerConfig(working=["☱", "☲", "☴"], success="☰", fail="✖"), 

155 earth=SpinnerConfig(working=["🌍", "🌎", "🌏"], success="✔️", fail="❌"), 

156 growing_dots=SpinnerConfig( 

157 working=["⣀", "⣄", "⣤", "⣦", "⣶", "⣷", "⣿"], success="⣿", fail="✖" 

158 ), 

159 dice=SpinnerConfig(working=["⚀", "⚁", "⚂", "⚃", "⚄", "⚅"], success="🎲", fail="✖"), 

160 wifi=SpinnerConfig( 

161 working=["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"], success="✔️", fail="❌" 

162 ), 

163 bounce=SpinnerConfig(working=["⠁", "⠂", "⠄", "⠂"], success="⠿", fail="⢿"), 

164 arc=SpinnerConfig(working=["◜", "◠", "◝", "◞", "◡", "◟"], success="○", fail="✖"), 

165 toggle=SpinnerConfig(working=["⊶", "⊷"], success="⊷", fail="⊗"), 

166 toggle2=SpinnerConfig(working=["▫", "▪"], success="▪", fail="✖"), 

167 toggle3=SpinnerConfig(working=["□", "■"], success="■", fail="✖"), 

168 toggle4=SpinnerConfig(working=["■", "□", "▪", "▫"], success="■", fail="✖"), 

169 toggle5=SpinnerConfig(working=["▮", "▯"], success="▮", fail="✖"), 

170 toggle7=SpinnerConfig(working=["⦾", "⦿"], success="⦿", fail="✖"), 

171 toggle8=SpinnerConfig(working=["◍", "◌"], success="◍", fail="✖"), 

172 toggle9=SpinnerConfig(working=["◉", "◎"], success="◉", fail="✖"), 

173 arrow2=SpinnerConfig( 

174 working=["⬆️ ", "↗️ ", "➡️ ", "↘️ ", "⬇️ ", "↙️ ", "⬅️ ", "↖️ "], success="➡️", fail="❌" 

175 ), 

176 point=SpinnerConfig( 

177 working=["∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙"], success="●●●", fail="xxx" 

178 ), 

179 layer=SpinnerConfig(working=["-", "=", "≡"], success="≡", fail="✖"), 

180 speaker=SpinnerConfig( 

181 working=["🔈 ", "🔉 ", "🔊 ", "🔉 "], success="🔊", fail="🔇" 

182 ), 

183 orangePulse=SpinnerConfig( 

184 working=["🔸 ", "🔶 ", "🟠 ", "🟠 ", "🔷 "], success="🟠", fail="❌" 

185 ), 

186 bluePulse=SpinnerConfig( 

187 working=["🔹 ", "🔷 ", "🔵 ", "🔵 ", "🔷 "], success="🔵", fail="❌" 

188 ), 

189 satellite_signal=SpinnerConfig( 

190 working=["📡 ", "📡· ", "📡·· ", "📡···", "📡 ··", "📡 ·"], 

191 success="📡 ✔️ ", 

192 fail="📡 ❌ ", 

193 ), 

194 rocket_orbit=SpinnerConfig( 

195 working=["🌍🚀 ", "🌏 🚀 ", "🌎 🚀"], success="🌍 ✨", fail="🌍 💥" 

196 ), 

197 ogham=SpinnerConfig(working=["ᚁ ", "ᚂ ", "ᚃ ", "ᚄ", "ᚅ"], success="᚛᚜", fail="✖"), 

198 eth=SpinnerConfig( 

199 working=["᛫", "፡", "፥", "፤", "፧", "።", "፨"], success="፠", fail="✖" 

200 ), 

201) 

202# spinner configurations 

203 

204 

class Spinner:
    """displays a spinner, and optionally elapsed time and a mutable value while a function is running.

    # Parameters:

    - `config : SpinnerConfigArg`
        spinner appearance: a key into `SPINNERS`, a list of working frames,
        a dict of `SpinnerConfig` keyword arguments, or a `SpinnerConfig`
        (defaults to `"default"`)
    - `update_interval : float`
        how often to update the spinner display in seconds
        (defaults to `0.1`)
    - `initial_value : str`
        initial value to display with the spinner
        (defaults to `""`)
    - `message : str`
        message to display with the spinner
        (defaults to `""`)
    - `format_string : str`
        string to format the spinner with. must have `"\\r"` prepended to clear the line.
        allowed keys are `spinner`, `elapsed_time`, `message`, and `value`
        (defaults to `"\\r{spinner} ({elapsed_time:.2f}s) {message}{value}"`)
    - `output_stream : TextIO`
        stream to write the spinner to
        (defaults to `sys.stdout`)
    - `format_string_when_updated : Union[bool,str]`
        whether to use a different format string when the value is updated.
        if `True`, use `format_string` with a newline appended. if a string, use that string.
        this is useful if you want update_value to print to console and be preserved.
        (defaults to `False`)

    # Deprecated Parameters:

    - `spinner_chars : Union[str, Sequence[str]]`
    - `spinner_complete : str`
        accepted only for backwards compatibility: passing either emits a
        `DeprecationWarning` and has no effect. use `config` instead.

    # Raises:
    - `ValueError` : for positional arguments, unrecognized keyword arguments,
      or a format string that cannot be formatted with the keys above
    - `TypeError` : if `format_string_when_updated` is neither a bool nor a string
    - anything raised by `SpinnerConfig.from_any(config)`

    # Methods:
    - `start() -> None`
        start spinning in a background thread
    - `update_value(value: Any) -> None`
        update the current value displayed by the spinner
    - `stop(failed: bool = False) -> None`
        stop spinning and print the final success/fail line

    # Usage:

    ## As a context manager:
    ```python
    with SpinnerContext() as sp:
        for i in range(1):
            time.sleep(0.1)
            sp.update_value(f"Step {i+1}")
    ```

    ## As a decorator:
    ```python
    @spinner_decorator(mutable_kwarg_key="update_status")
    def long_running_function(update_status):
        for i in range(1):
            time.sleep(0.1)
            update_status(f"Step {i+1}")
        return "Function completed"
    ```
    """

    def __init__(
        self,
        # no positional args
        *args: Any,
        config: SpinnerConfigArg = "default",
        update_interval: float = 0.1,
        initial_value: str = "",
        message: str = "",
        format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
        output_stream: TextIO = sys.stdout,
        format_string_when_updated: Union[str, bool] = False,
        # deprecated
        spinner_chars: Optional[Union[str, Sequence[str]]] = None,
        spinner_complete: Optional[str] = None,
        # no other kwargs accepted
        **kwargs: Any,
    ):
        if args:
            raise ValueError(f"Spinner does not accept positional arguments: {args}")
        if kwargs:
            raise ValueError(
                f"Spinner did not recognize these keyword arguments: {kwargs}"
            )

        # old spinner display
        if (spinner_chars is not None) or (spinner_complete is not None):
            warnings.warn(
                "spinner_chars and spinner_complete are deprecated and will have no effect. Use `config` instead.",
                DeprecationWarning,
            )

        # config
        self.config: SpinnerConfig = SpinnerConfig.from_any(config)

        # format string to use (once) after the value is updated, or None to
        # always use `format_string`
        self.format_string_when_updated: Optional[str] = None
        if format_string_when_updated is True:
            # reuse `format_string`, with a newline so the updated line is preserved
            self.format_string_when_updated = format_string + "\n"
        elif isinstance(format_string_when_updated, str):
            # use the provided format string
            self.format_string_when_updated = format_string_when_updated
        elif format_string_when_updated is not False:
            raise TypeError(
                "format_string_when_updated must be a string or True, got"
                + f" {type(format_string_when_updated) = }, {format_string_when_updated = }"
            )

        # copy other kwargs
        self.update_interval: float = update_interval
        self.message: str = message
        self.current_value: Any = initial_value
        self.format_string: str = format_string
        self.output_stream: TextIO = output_stream

        # try out every format string now, so a bad one fails here rather than
        # as an exception inside the spinner thread
        for fmt in (self.format_string, self.format_string_when_updated):
            if fmt is None:
                continue
            try:
                fmt.format(
                    spinner=self.config.working[0],
                    elapsed_time=0.0,
                    message=self.message,
                    value=self.current_value,
                )
            except Exception as e:
                raise ValueError(
                    f"Invalid format string: {fmt}. Must take keys "
                    + "'spinner: str', 'elapsed_time: float', 'message: str', and 'value: Any'."
                ) from e

        # init
        # for measuring elapsed time
        self.start_time: float = 0
        # set by `stop()` to end the spinner thread's loop
        self.stop_spinner: threading.Event = threading.Event()
        # the thread running the spinner
        self.spinner_thread: Optional[threading.Thread] = None
        # whether the value has been updated since the last display
        self.value_changed: bool = False
        # width of the terminal, for padding with spaces
        self.term_width: int
        try:
            self.term_width = os.get_terminal_size().columns
        except OSError:
            self.term_width = 80

        # state of the spinner
        self.state: Literal["initialized", "running", "success", "fail"] = "initialized"

    def _render(self, format_str: str, spinner: str) -> str:
        "format `format_str` with the current display state, padded to the terminal width"
        return format_str.format(
            spinner=spinner,  # str
            elapsed_time=time.time() - self.start_time,  # float
            message=self.message,  # str
            value=self.current_value,  # Any, but will be formatted as str
        ).ljust(self.term_width)

    def spin(self) -> None:
        "Function to run in a separate thread, displaying the spinner and optional information"
        i: int = 0
        while not self.stop_spinner.is_set():
            # get current spinner str
            frames = self.config.working
            spinner: str = frames[i % len(frames)]

            # use the special format string once after each value update, if configured
            format_str: str = self.format_string
            if self.value_changed and (self.format_string_when_updated is not None):
                self.value_changed = False
                format_str = self.format_string_when_updated

            # write and flush the display string
            self.output_stream.write(self._render(format_str, spinner))
            self.output_stream.flush()

            # wait for the next update; `Event.wait` (unlike `time.sleep`) returns
            # as soon as `stop()` sets the event, so stopping is not delayed
            self.stop_spinner.wait(self.update_interval)
            i += 1

    def update_value(self, value: Any) -> None:
        "Update the current value displayed by the spinner"
        self.current_value = value
        self.value_changed = True

    def start(self) -> None:
        "Start the spinner in a background thread"
        self.start_time = time.time()
        # allow a stopped spinner to be started again
        self.stop_spinner.clear()
        # daemon thread: if the program exits without reaching `stop()`, the
        # spinner thread must not keep the process alive
        self.spinner_thread = threading.Thread(target=self.spin, daemon=True)
        self.spinner_thread.start()
        self.state = "running"

    def stop(self, failed: bool = False) -> None:
        "Stop the spinner, then write the final success/fail line and a newline"
        # stop and join the spinner thread *before* writing the final line, so a
        # frame written late by the thread cannot overwrite the success/fail marker
        self.stop_spinner.set()
        if self.spinner_thread is not None:
            self.spinner_thread.join()

        self.output_stream.write(
            self._render(
                self.format_string,
                self.config.fail if failed else self.config.success,
            )
        )
        self.output_stream.write("\n")
        self.output_stream.flush()

        self.state = "fail" if failed else "success"

416 

class NoOpContextManager(ContextManager):  # type: ignore[type-arg]
    """A context manager that does nothing.

    Any constructor arguments are accepted and ignored, `__enter__` hands back
    the manager itself, and `__exit__` never suppresses exceptions.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # arguments are accepted only so this can stand in for a real context manager
        return None

    def __enter__(self) -> "NoOpContextManager":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # a falsy (None) result lets any exception from the `with` body propagate
        return None

433 

434 

class SpinnerContext(Spinner, ContextManager):
    "see `Spinner` for parameters"

    def __enter__(self) -> "SpinnerContext":
        # begin spinning as soon as the `with` block is entered
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # any exception escaping the `with` block marks the spinner as failed;
        # returning None lets that exception keep propagating
        failed: bool = exc_type is not None
        self.stop(failed=failed)


SpinnerContext.__doc__ = Spinner.__doc__

452 

453 

# TODO: type hint that the `update_status` kwarg is not needed when calling the function we just decorated
def spinner_decorator(
    *args: Any,
    # passed to `Spinner.__init__`
    config: SpinnerConfigArg = "default",
    update_interval: float = 0.1,
    initial_value: str = "",
    message: str = "",
    # leading "\r" returns to the start of the line so each frame overwrites the last
    format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
    output_stream: TextIO = sys.stdout,
    format_string_when_updated: Union[str, bool] = False,
    # new kwarg
    mutable_kwarg_key: Optional[str] = None,
    # deprecated
    spinner_chars: Union[str, Sequence[str], None] = None,
    spinner_complete: Optional[str] = None,
    **kwargs: Any,
) -> Callable[[DecoratedFunction], DecoratedFunction]:
    """see `Spinner` for parameters. Also takes `mutable_kwarg_key`

    `mutable_kwarg_key` is the key with which `Spinner().update_value`
    will be passed to the decorated function. if `None`, won't pass it.

    usable bare (`@spinner_decorator`) or with keyword arguments
    (`@spinner_decorator(message="working: ")`). a new `Spinner` is created for
    every call of the decorated function. if the function raises anything
    (including `KeyboardInterrupt`), the spinner is stopped as failed and the
    exception is re-raised.

    # Raises:
    - `ValueError` : if given more than one positional argument, a non-callable
      positional argument, or unrecognized keyword arguments
    """

    # the only positional argument allowed is the function itself (bare `@spinner_decorator`)
    if len(args) > 1 or (args and not callable(args[0])):
        raise ValueError(
            f"spinner_decorator does not accept positional arguments: {args}"
        )
    if kwargs:
        raise ValueError(
            f"spinner_decorator did not recognize these keyword arguments: {kwargs}"
        )

    def decorator(func: DecoratedFunction) -> DecoratedFunction:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            spinner: Spinner = Spinner(
                config=config,
                update_interval=update_interval,
                initial_value=initial_value,
                message=message,
                format_string=format_string,
                output_stream=output_stream,
                format_string_when_updated=format_string_when_updated,
                spinner_chars=spinner_chars,
                spinner_complete=spinner_complete,
            )

            if mutable_kwarg_key:
                kwargs[mutable_kwarg_key] = spinner.update_value

            spinner.start()
            try:
                result: Any = func(*args, **kwargs)
            except BaseException:
                # BaseException, not Exception: a KeyboardInterrupt must also stop
                # the spinner thread, otherwise it keeps spinning after the error
                spinner.stop(failed=True)
                raise
            # outside the `try`, so an error inside `stop()` itself is not
            # misreported by stopping a second time as failed
            spinner.stop(failed=False)
            return result

        # TODO: fix this type ignore
        return wrapper  # type: ignore[return-value]

    if not args:
        # called as `@spinner_decorator(stuff)`
        return decorator
    else:
        # called as `@spinner_decorator` without parens
        return decorator(args[0])

523 

524 

# `spinner_decorator` forwards its keyword arguments to `Spinner`, so append the
# shared `Spinner` documentation to the decorator's own docstring rather than
# replacing it -- replacing it would drop the notes on `mutable_kwarg_key`.
# `or None` keeps `__doc__` as None when docstrings are stripped (python -OO).
spinner_decorator.__doc__ = (
    "\n\n".join(doc for doc in (spinner_decorator.__doc__, Spinner.__doc__) if doc)
    or None
)