Coverage for muutils/spinner.py: 93%
147 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""decorator `spinner_decorator` and context manager `SpinnerContext` to display a spinner
3using the base `Spinner` class while some code is running.
4"""
6import os
7import time
8from dataclasses import dataclass, field
9import threading
10import sys
11from functools import wraps
12from typing import (
13 List,
14 Dict,
15 Callable,
16 Any,
17 Literal,
18 Optional,
19 TextIO,
20 TypeVar,
21 Sequence,
22 Union,
23 ContextManager,
24)
25import warnings
# generic stand-in for "whatever callable is being decorated", so a decorator can
# declare that it hands back something of the same type it received
DecoratedFunction = TypeVar(
    "DecoratedFunction",
    bound=Callable[..., Any],
)
"Define a generic type for the decorated function"
@dataclass
class SpinnerConfig:
    """characters used to render a spinner

    - `working`: frames cycled through while the spinner is running
    - `success`: shown in place of the frame when the spinner stops successfully
    - `fail`: shown in place of the frame when the spinner stops with a failure

    # Raises:
    - `ValueError` : on construction, if the config is not valid (see `is_valid`)
    """

    working: List[str] = field(default_factory=lambda: ["|", "/", "-", "\\"])
    success: str = "✔️"
    fail: str = "❌"

    def is_ascii(self) -> bool:
        "whether all characters (working frames, `success`, and `fail`) are ascii"
        return all(s.isascii() for s in self.working + [self.success, self.fail])

    def eq_lens(self) -> bool:
        """whether every working frame, plus `success` and `fail`, has the same
        length as the first working frame"""
        expected_len: int = len(self.working[0])
        return all(
            len(char) == expected_len
            for char in self.working + [self.success, self.fail]
        )

    def is_valid(self) -> bool:
        """whether the spinner config is valid

        the checks are chained with `and` so that `len()` and iteration over
        `working` only happen once `working` is known to be a list; otherwise a
        non-list `working` (e.g. `None`) would raise `TypeError` here instead of
        this method returning `False`
        """
        return (
            isinstance(self.working, list)
            and len(self.working) > 0
            and all(isinstance(char, str) for char in self.working)
            and isinstance(self.success, str)
            and isinstance(self.fail, str)
        )

    def __post_init__(self):
        if not self.is_valid():
            raise ValueError(f"Invalid SpinnerConfig: {self}")

    @classmethod
    def from_any(cls, arg: "SpinnerConfigArg") -> "SpinnerConfig":
        """build a `SpinnerConfig` from any accepted form

        - `str`: key to look up in `SPINNERS` (the shared instance is returned)
        - `list`: the working frames (copied, so later changes to the caller's
          list do not leak into the config)
        - `dict`: keyword arguments to the constructor
        - `SpinnerConfig`: returned unchanged

        # Raises:
        - `KeyError` : if a string key is not in `SPINNERS`
        - `TypeError` : if `arg` is none of the accepted types
        - `ValueError` : if the resulting config is invalid
        """
        if isinstance(arg, str):
            try:
                return SPINNERS[arg]
            except KeyError:
                raise KeyError(
                    f"unknown spinner config {arg!r}, expected one of: {list(SPINNERS)}"
                ) from None
        elif isinstance(arg, list):
            return cls(working=list(arg))
        elif isinstance(arg, dict):
            return cls(**arg)
        elif isinstance(arg, SpinnerConfig):
            return arg
        else:
            raise TypeError(
                f"to create a SpinnerConfig, you must pass a string (key), list (working seq), dict (kwargs to SpinnerConfig), or SpinnerConfig, but got {type(arg) = }, {arg = }"
            )
# anything `SpinnerConfig.from_any` accepts: a key into `SPINNERS`, a list of
# working frames, a dict of `SpinnerConfig` kwargs, or a `SpinnerConfig` itself
SpinnerConfigArg = Union[str, List[str], SpinnerConfig, dict]

# named spinner configurations, looked up when `SpinnerConfig.from_any` (and hence
# `Spinner(config=...)`) is given a string key
# NOTE(review): in this copy several multi-character frames appear to have lost
# internal whitespace -- e.g. the first four `bouncing_ball` frames are identical
# strings, and the `bouncing_bar` / `bar` frames are shorter than their `success`
# strings. `eq_lens()` is not enforced anywhere, so nothing catches this; confirm
# these literals against the upstream source.
SPINNERS: Dict[str, SpinnerConfig] = dict(
    default=SpinnerConfig(working=["|", "/", "-", "\\"], success="#", fail="X"),
    dots=SpinnerConfig(working=[". ", ".. ", "..."], success="***", fail="xxx"),
    bars=SpinnerConfig(working=["| ", "|| ", "|||"], success="|||", fail="///"),
    arrows=SpinnerConfig(working=["<", "^", ">", "v"], success="►", fail="✖"),
    arrows_2=SpinnerConfig(
        working=["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"], success="→", fail="↯"
    ),
    bouncing_bar=SpinnerConfig(
        working=["[ ]", "[= ]", "[== ]", "[=== ]", "[ ===]", "[ ==]", "[ =]"],
        success="[====]",
        fail="[XXXX]",
    ),
    bar=SpinnerConfig(
        working=["[ ]", "[- ]", "[--]", "[ -]"],
        success="[==]",
        fail="[xx]",
    ),
    bouncing_ball=SpinnerConfig(
        working=[
            "( ● )",
            "( ● )",
            "( ● )",
            "( ● )",
            "( ●)",
            "( ● )",
            "( ● )",
            "( ● )",
            "( ● )",
            "(● )",
        ],
        success="(●●●●●●)",
        fail="( ✖ )",
    ),
    ooo=SpinnerConfig(working=[".", "o", "O", "o"], success="O", fail="x"),
    braille=SpinnerConfig(
        working=["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
        success="⣿",
        fail="X",
    ),
    clock=SpinnerConfig(
        working=[
            "🕛",
            "🕐",
            "🕑",
            "🕒",
            "🕓",
            "🕔",
            "🕕",
            "🕖",
            "🕗",
            "🕘",
            "🕙",
            "🕚",
        ],
        success="✔️",
        fail="❌",
    ),
    hourglass=SpinnerConfig(working=["⏳", "⌛"], success="✔️", fail="❌"),
    square_corners=SpinnerConfig(working=["◰", "◳", "◲", "◱"], success="◼", fail="✖"),
    triangle=SpinnerConfig(working=["◢", "◣", "◤", "◥"], success="◆", fail="✖"),
    square_dot=SpinnerConfig(
        working=["⣷", "⣯", "⣟", "⡿", "⢿", "⣻", "⣽", "⣾"], success="⣿", fail="❌"
    ),
    box_bounce=SpinnerConfig(working=["▌", "▀", "▐", "▄"], success="■", fail="✖"),
    hamburger=SpinnerConfig(working=["☱", "☲", "☴"], success="☰", fail="✖"),
    earth=SpinnerConfig(working=["🌍", "🌎", "🌏"], success="✔️", fail="❌"),
    growing_dots=SpinnerConfig(
        working=["⣀", "⣄", "⣤", "⣦", "⣶", "⣷", "⣿"], success="⣿", fail="✖"
    ),
    dice=SpinnerConfig(working=["⚀", "⚁", "⚂", "⚃", "⚄", "⚅"], success="🎲", fail="✖"),
    wifi=SpinnerConfig(
        working=["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"], success="✔️", fail="❌"
    ),
    bounce=SpinnerConfig(working=["⠁", "⠂", "⠄", "⠂"], success="⠿", fail="⢿"),
    arc=SpinnerConfig(working=["◜", "◠", "◝", "◞", "◡", "◟"], success="○", fail="✖"),
    toggle=SpinnerConfig(working=["⊶", "⊷"], success="⊷", fail="⊗"),
    toggle2=SpinnerConfig(working=["▫", "▪"], success="▪", fail="✖"),
    toggle3=SpinnerConfig(working=["□", "■"], success="■", fail="✖"),
    toggle4=SpinnerConfig(working=["■", "□", "▪", "▫"], success="■", fail="✖"),
    toggle5=SpinnerConfig(working=["▮", "▯"], success="▮", fail="✖"),
    toggle7=SpinnerConfig(working=["⦾", "⦿"], success="⦿", fail="✖"),
    toggle8=SpinnerConfig(working=["◍", "◌"], success="◍", fail="✖"),
    toggle9=SpinnerConfig(working=["◉", "◎"], success="◉", fail="✖"),
    arrow2=SpinnerConfig(
        working=["⬆️ ", "↗️ ", "➡️ ", "↘️ ", "⬇️ ", "↙️ ", "⬅️ ", "↖️ "], success="➡️", fail="❌"
    ),
    point=SpinnerConfig(
        working=["∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙"], success="●●●", fail="xxx"
    ),
    layer=SpinnerConfig(working=["-", "=", "≡"], success="≡", fail="✖"),
    speaker=SpinnerConfig(
        working=["🔈 ", "🔉 ", "🔊 ", "🔉 "], success="🔊", fail="🔇"
    ),
    orangePulse=SpinnerConfig(
        working=["🔸 ", "🔶 ", "🟠 ", "🟠 ", "🔷 "], success="🟠", fail="❌"
    ),
    bluePulse=SpinnerConfig(
        working=["🔹 ", "🔷 ", "🔵 ", "🔵 ", "🔷 "], success="🔵", fail="❌"
    ),
    satellite_signal=SpinnerConfig(
        working=["📡 ", "📡· ", "📡·· ", "📡···", "📡 ··", "📡 ·"],
        success="📡 ✔️ ",
        fail="📡 ❌ ",
    ),
    rocket_orbit=SpinnerConfig(
        working=["🌍🚀 ", "🌏 🚀 ", "🌎 🚀"], success="🌍 ✨", fail="🌍 💥"
    ),
    ogham=SpinnerConfig(working=["ᚁ ", "ᚂ ", "ᚃ ", "ᚄ", "ᚅ"], success="᚛᚜", fail="✖"),
    eth=SpinnerConfig(
        working=["᛫", "፡", "፥", "፤", "፧", "።", "፨"], success="፠", fail="✖"
    ),
)
# end of spinner configurations
class Spinner:
    """displays a spinner, and optionally elapsed time and a mutable value while a function is running.

    # Parameters:

    - `config : SpinnerConfigArg`
        which spinner characters to use: a key into `SPINNERS`, a list of working
        frames, a dict of `SpinnerConfig` kwargs, or a `SpinnerConfig`
        (defaults to `"default"`)
    - `update_interval : float`
        how often to update the spinner display in seconds
        (defaults to `0.1`)
    - `initial_value : str`
        initial value to display with the spinner
        (defaults to `""`)
    - `message : str`
        message to display with the spinner
        (defaults to `""`)
    - `format_string : str`
        string to format the spinner with. must have `"\\r"` prepended to clear the line.
        allowed keys are `spinner`, `elapsed_time`, `message`, and `value`
        (defaults to `"\\r{spinner} ({elapsed_time:.2f}s) {message}{value}"`)
    - `output_stream : TextIO`
        stream to write the spinner to
        (defaults to `sys.stdout`)
    - `format_string_when_updated : Union[bool,str]`
        whether to use a different format string when the value is updated.
        if `True`, use `format_string` with a newline appended. if a string, use that string.
        this is useful if you want update_value to print to console and be preserved.
        (defaults to `False`)

    # Deprecated Parameters:

    - `spinner_chars : Union[str, Sequence[str]]`
        ignored; passing it emits a `DeprecationWarning`. use `config` instead.
    - `spinner_complete : str`
        ignored; passing it emits a `DeprecationWarning`. use `config` instead.

    # Raises:
    - `ValueError` : on positional arguments, unknown keyword arguments, or a
      format string that cannot be formatted with the allowed keys
    - `TypeError` : if `format_string_when_updated` is not a bool or string

    # Methods:
    - `update_value(value: Any) -> None`
        update the current value displayed by the spinner
    - `start() -> None`
        start the spinner thread
    - `stop(failed: bool = False) -> None`
        stop the spinner thread and write the final success/fail line

    # Usage:

    ## As a context manager:
    ```python
    with SpinnerContext() as sp:
        for i in range(1):
            time.sleep(0.1)
            sp.update_value(f"Step {i+1}")
    ```

    ## As a decorator:
    ```python
    @spinner_decorator(mutable_kwarg_key="update_status")
    def long_running_function(update_status):
        for i in range(1):
            time.sleep(0.1)
            update_status(f"Step {i+1}")
        return "Function completed"
    ```
    """

    def __init__(
        self,
        # no positional args
        *args,
        config: SpinnerConfigArg = "default",
        update_interval: float = 0.1,
        initial_value: str = "",
        message: str = "",
        format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
        output_stream: TextIO = sys.stdout,
        format_string_when_updated: Union[str, bool] = False,
        # deprecated
        spinner_chars: Optional[Union[str, Sequence[str]]] = None,
        spinner_complete: Optional[str] = None,
        # no other kwargs accepted
        **kwargs: Any,
    ):
        if args:
            raise ValueError(f"Spinner does not accept positional arguments: {args}")
        if kwargs:
            raise ValueError(
                f"Spinner did not recognize these keyword arguments: {kwargs}"
            )

        # old spinner display
        if (spinner_chars is not None) or (spinner_complete is not None):
            warnings.warn(
                "spinner_chars and spinner_complete are deprecated and will have no effect. Use `config` instead.",
                DeprecationWarning,
            )

        # config
        self.config: SpinnerConfig = SpinnerConfig.from_any(config)

        # special format string for when the value is updated
        self.format_string_when_updated: Optional[str] = None
        "format string to use when the value is updated"
        if format_string_when_updated is True:
            # reuse the main format string, but end with a newline so the line is kept
            self.format_string_when_updated = format_string + "\n"
        elif isinstance(format_string_when_updated, str):
            # use the provided format string
            self.format_string_when_updated = format_string_when_updated
        elif format_string_when_updated is not False:
            raise TypeError(
                "format_string_when_updated must be a string or True, got"
                + f" {type(format_string_when_updated) = }, {format_string_when_updated = }"
            )

        # copy other kwargs
        self.update_interval: float = update_interval
        self.message: str = message
        self.current_value: Any = initial_value
        self.format_string: str = format_string
        self.output_stream: TextIO = output_stream

        # test out every format string now, so a bad one fails here in the caller's
        # thread instead of later inside the spinner thread
        format_strings_to_check: List[str] = [self.format_string]
        if self.format_string_when_updated is not None:
            format_strings_to_check.append(self.format_string_when_updated)
        for fmt in format_strings_to_check:
            try:
                fmt.format(
                    spinner=self.config.working[0],
                    elapsed_time=0.0,
                    message=self.message,
                    value=self.current_value,
                )
            except Exception as e:
                raise ValueError(
                    f"Invalid format string: {fmt}. Must take keys "
                    + "'spinner: str', 'elapsed_time: float', 'message: str', and 'value: Any'."
                ) from e

        # init
        self.start_time: float = 0
        "for measuring elapsed time"
        self.stop_spinner: threading.Event = threading.Event()
        "to stop the spinner"
        self.spinner_thread: Optional[threading.Thread] = None
        "the thread running the spinner"
        self.value_changed: bool = False
        "whether the value has been updated since the last display"
        self.term_width: int
        "width of the terminal, for padding with spaces"
        try:
            self.term_width = os.get_terminal_size().columns
        except OSError:
            self.term_width = 80

        # state of the spinner
        self.state: Literal["initialized", "running", "success", "fail"] = "initialized"

    def spin(self) -> None:
        "Function to run in a separate thread, displaying the spinner and optional information"
        frames: List[str] = self.config.working
        i: int = 0
        while not self.stop_spinner.is_set():
            # get current spinner str
            spinner: str = frames[i % len(frames)]

            # args for display string
            display_parts: Dict[str, Any] = dict(
                spinner=spinner,  # str
                elapsed_time=time.time() - self.start_time,  # float
                message=self.message,  # str
                value=self.current_value,  # Any, but will be formatted as str
            )

            # use the special one if needed
            format_str: str = self.format_string
            if self.value_changed and (self.format_string_when_updated is not None):
                self.value_changed = False
                format_str = self.format_string_when_updated

            # write and flush the display string
            output: str = format_str.format(**display_parts).ljust(self.term_width)
            self.output_stream.write(output)
            self.output_stream.flush()

            # wait for the next update; unlike `time.sleep`, this returns as soon as
            # `stop()` sets the event, so stopping does not lag by `update_interval`
            self.stop_spinner.wait(self.update_interval)
            i += 1

    def update_value(self, value: Any) -> None:
        "Update the current value displayed by the spinner"
        self.current_value = value
        self.value_changed = True

    def start(self) -> None:
        "Start the spinner"
        # clear a previous stop request so a stopped spinner can be started again
        self.stop_spinner.clear()
        self.start_time = time.time()
        # daemon, so a spinner that is never stopped cannot keep the process alive
        self.spinner_thread = threading.Thread(target=self.spin, daemon=True)
        self.spinner_thread.start()
        self.state = "running"

    def stop(self, failed: bool = False) -> None:
        "Stop the spinner, then write the final success/fail line followed by a newline"
        # stop and join the spinner thread *before* writing the final line; otherwise
        # a frame written by the thread in the meantime can overwrite the final line
        self.stop_spinner.set()
        if self.spinner_thread:
            self.spinner_thread.join()

        self.output_stream.write(
            self.format_string.format(
                spinner=self.config.fail if failed else self.config.success,
                elapsed_time=time.time() - self.start_time,  # float
                message=self.message,  # str
                value=self.current_value,  # Any, but will be formatted as str
            ).ljust(self.term_width)
        )
        self.output_stream.write("\n")
        self.output_stream.flush()

        self.state = "fail" if failed else "success"
class NoOpContextManager(ContextManager):
    """A context manager that does nothing.

    Any constructor arguments are accepted and ignored, entering yields the
    instance itself, and exiting never suppresses an exception.
    """

    def __init__(self, *_args, **_kwargs):
        # arguments are accepted only so this can stand in for a real context manager
        return None

    def __enter__(self):
        # bind the instance itself for `with NoOpContextManager() as x`
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # a falsy return lets any exception from the `with` body propagate
        return None
class SpinnerContext(Spinner, ContextManager):
    "see `Spinner` for parameters"

    def __enter__(self) -> "SpinnerContext":
        # launch the spinner thread, then hand the spinner to the `with` body
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # a non-None exception type means the body raised: show the fail marker.
        # returning None lets the exception propagate.
        body_raised: bool = exc_type is not None
        self.stop(failed=body_raised)


# reuse the full parameter documentation of the base class
SpinnerContext.__doc__ = Spinner.__doc__
# TODO: type hint that the `update_status` kwarg is not needed when calling the function we just decorated
def spinner_decorator(
    *args,
    # passed to `Spinner.__init__`
    config: SpinnerConfigArg = "default",
    update_interval: float = 0.1,
    initial_value: str = "",
    message: str = "",
    # `Spinner` requires the leading "\r" so each frame overwrites the previous one
    format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
    output_stream: TextIO = sys.stdout,
    format_string_when_updated: Union[str, bool] = False,
    # new kwarg
    mutable_kwarg_key: Optional[str] = None,
    # deprecated
    spinner_chars: Union[str, Sequence[str], None] = None,
    spinner_complete: Optional[str] = None,
    **kwargs,
) -> Callable[[DecoratedFunction], DecoratedFunction]:
    """display a `Spinner` while the decorated function runs.

    usable bare (`@spinner_decorator`) or with keyword arguments
    (`@spinner_decorator(message="...")`). every keyword argument except
    `mutable_kwarg_key` is forwarded to `Spinner`; see `Spinner` for those.

    - `mutable_kwarg_key : Optional[str]`
        the key with which `Spinner().update_value` will be passed to the
        decorated function. if `None`, it is not passed.
        (defaults to `None`)

    if the decorated function raises (including `KeyboardInterrupt`), the
    spinner is stopped with the fail marker and the exception is re-raised.

    # Raises:
    - `ValueError` : if given more than one positional argument, a
      non-callable positional argument, or unrecognized keyword arguments
    """

    # the only positional argument allowed is the function itself (bare usage);
    # e.g. `@spinner_decorator("dots")` must be written with `config="dots"`
    if len(args) > 1 or (args and not callable(args[0])):
        raise ValueError(
            f"spinner_decorator does not accept positional arguments: {args}"
        )
    if kwargs:
        raise ValueError(
            f"spinner_decorator did not recognize these keyword arguments: {kwargs}"
        )

    def decorator(func: DecoratedFunction) -> DecoratedFunction:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # a fresh spinner per call, so concurrent or repeated calls don't share state
            spinner: Spinner = Spinner(
                config=config,
                update_interval=update_interval,
                initial_value=initial_value,
                message=message,
                format_string=format_string,
                output_stream=output_stream,
                format_string_when_updated=format_string_when_updated,
                spinner_chars=spinner_chars,
                spinner_complete=spinner_complete,
            )

            if mutable_kwarg_key:
                kwargs[mutable_kwarg_key] = spinner.update_value

            spinner.start()
            try:
                result: Any = func(*args, **kwargs)
            except BaseException:
                # BaseException, not Exception: on KeyboardInterrupt / SystemExit the
                # spinner thread must still be stopped. bare `raise` keeps the traceback.
                spinner.stop(failed=True)
                raise
            # outside the `try`, so an error while stopping isn't reported as a
            # function failure (which would call `stop` a second time)
            spinner.stop(failed=False)
            return result

        # TODO: fix this type ignore
        return wrapper  # type: ignore[return-value]

    if not args:
        # called as `@spinner_decorator(stuff)`
        return decorator
    else:
        # called as `@spinner_decorator` without parens
        return decorator(args[0])


# keep the decorator-specific notes AND the full `Spinner` parameter docs.
# docstrings can be `None` (e.g. under `python -OO`), so only join those present
spinner_decorator.__doc__ = (
    "\n\n".join(doc for doc in (spinner_decorator.__doc__, Spinner.__doc__) if doc)
    or None
)