Coverage for muutils/nbutils/configure_notebook.py: 40%
133 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""shared utilities for setting up a notebook"""
3from __future__ import annotations
5import os
6import typing
7import warnings
9import matplotlib.pyplot as plt # type: ignore[import]
class PlotlyNotInstalledWarning(UserWarning):
    """Warning emitted at import time when `plotly.io` cannot be imported (plotly plots are then unavailable)."""
# plotly is optional: remember whether `plotly.io` could be imported so that
# later code (e.g. the renderer/template setup in `configure_notebook`) can
# skip plotly-specific configuration when it is absent
PLOTLY_IMPORTED: bool
try:
    import plotly.io as pio  # type: ignore[import]

    PLOTLY_IMPORTED = True
except ImportError:
    warnings.warn(
        "Plotly not installed. Plotly plots will not be available.",
        PlotlyNotInstalledWarning,
    )
    PLOTLY_IMPORTED = False
# Detect whether an IPython shell is active. Start from "no" and only flip to
# "yes" when IPython is importable AND `get_ipython()` returns a shell.
IN_JUPYTER: bool = False
try:
    from IPython import get_ipython  # type: ignore[import-not-found]

    IN_JUPYTER = get_ipython() is not None
except ImportError:
    # IPython missing: deliberately keep the default of False
    pass
37# muutils imports
38from muutils.mlutils import get_device, set_reproducibility # noqa: E402
40# handling figures
41PlottingMode = typing.Literal["ignore", "inline", "widget", "save"]
42PLOT_MODE: PlottingMode = "inline"
43CONVERSION_PLOTMODE_OVERRIDE: PlottingMode | None = None
44FIG_COUNTER: int = 0
45FIG_OUTPUT_FMT: str | None = None
46FIG_NUMBERED_FNAME: str = "figure-{num}"
47FIG_CONFIG: dict | None = None
48FIG_BASEPATH: str | None = None
49CLOSE_AFTER_PLOTSHOW: bool = False
51MATPLOTLIB_FORMATS = ["pdf", "png", "jpg", "jpeg", "svg", "eps", "ps", "tif", "tiff"]
52TIKZPLOTLIB_FORMATS = ["tex", "tikz"]
class UnknownFigureFormatWarning(UserWarning):
    """Warning emitted by `universal_savefig` when a figure format is not in the known matplotlib/tikzplotlib lists."""
def universal_savefig(fname: str, fmt: str | None = None) -> None:
    """Save the current figure with matplotlib or tikzplotlib, depending on format

    # Parameters:
     - `fname : str`
        output path; `.{fmt}` is appended unless the path already ends with
        that extension (case-insensitive)
     - `fmt : str | None`
        output format, matched case-insensitively. If `None`, it is inferred
        from the extension of `fname`; an extensionless `fname` (such as the
        default numbered `figure-{num}`) silently uses the module-level
        `FIG_OUTPUT_FMT`.
        (defaults to `None`)

    # Warns:
     - `UnknownFigureFormatWarning` when the format is in neither
       `MATPLOTLIB_FORMATS` nor `TIKZPLOTLIB_FORMATS`; `FIG_OUTPUT_FMT` is
       used instead
     - `UnknownFigureFormatWarning` when no format can be determined at all
       (`FIG_OUTPUT_FMT` is `None`); the figure is saved to `fname` as-is
     - `UserWarning` when the fallback `FIG_OUTPUT_FMT` is itself unknown;
       matplotlib then decides the format
    """
    if fmt is None:
        # `splitext` only inspects the last path component, so dots in
        # directory names (e.g. "figures/v1.2/figure-1") are not mistaken for
        # an extension, and extensionless names fall back to the configured
        # default instead of triggering an unknown-format warning
        ext: str = os.path.splitext(fname)[1]
        fmt = ext[1:] if ext else FIG_OUTPUT_FMT

    if fmt is not None and (
        fmt.lower() not in MATPLOTLIB_FORMATS
        and fmt.lower() not in TIKZPLOTLIB_FORMATS
    ):
        warnings.warn(
            f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'",
            UnknownFigureFormatWarning,
        )
        fmt = FIG_OUTPUT_FMT

    if fmt is None:
        # no explicit, inferred, or configured format: without this guard the
        # extension check below would fail with `str.endswith(None)`
        warnings.warn(
            "No figure format given and FIG_OUTPUT_FMT is None, going with matplotlib default",
            UnknownFigureFormatWarning,
        )
        plt.savefig(fname, bbox_inches="tight")
        return

    fmt = fmt.lower()
    # compare against ".{fmt}" rather than "{fmt}" so that e.g. "mypdf" still
    # receives a ".pdf" suffix
    if not fname.lower().endswith(f".{fmt}"):
        fname += f".{fmt}"

    if fmt in MATPLOTLIB_FORMATS:
        plt.savefig(fname, format=fmt, bbox_inches="tight")
    elif fmt in TIKZPLOTLIB_FORMATS:
        # imported lazily so tikzplotlib is only required for TikZ output
        import tikzplotlib  # type: ignore[import]

        tikzplotlib.save(fname)
    else:
        # only reachable when FIG_OUTPUT_FMT itself is not a known format
        warnings.warn(f"Unknown format '{fmt}', going with matplotlib default")
        plt.savefig(fname, bbox_inches="tight")
def setup_plots(
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> None:
    """Set up plot saving/rendering options

    Writes the module-level plotting globals that `plotshow` reads and resets
    the figure counter. The figure-saving globals (`FIG_OUTPUT_FMT`,
    `FIG_NUMBERED_FNAME`, `FIG_CONFIG`, `FIG_BASEPATH`) are only updated in
    `"save"` mode.

    # Parameters:
     - `plot_mode : PlottingMode`
        how to display plots; ignored when `CONVERSION_PLOTMODE_OVERRIDE` is
        not `None`. `"ignore"` replaces `plt.show` with a no-op, which later
        calls do not restore.
        (defaults to `"inline"`)
     - `fig_output_fmt : str | None`
        default format for saved figures (save mode only)
        (defaults to `"pdf"`)
     - `fig_numbered_fname : str`
        template for unnamed figures, formatted with `num=<counter>`
        (save mode only)
        (defaults to `"figure-{num}"`)
     - `fig_config : dict | None`
        metadata; if given, it is written to `<fig_basepath>/config.json` and
        used to derive the default base path (save mode only)
        (defaults to `None`)
     - `fig_basepath : str | None`
        directory for saved figures, created if missing. If `None`, a path
        under `figures/` is derived from `fig_config` or the current time
        (save mode only).
        (defaults to `None`)
     - `close_after_plotshow : bool`
        whether `plotshow` closes the figure after showing/saving it
        (defaults to `False`)

    # Raises:
     - `RuntimeError` : if the effective mode is `"inline"` but `IN_JUPYTER` is False
     - `AssertionError` : if the effective mode is not a known `PlottingMode`
    """
    global \
        PLOT_MODE, \
        CONVERSION_PLOTMODE_OVERRIDE, \
        FIG_COUNTER, \
        FIG_OUTPUT_FMT, \
        FIG_NUMBERED_FNAME, \
        FIG_CONFIG, \
        FIG_BASEPATH, \
        CLOSE_AFTER_PLOTSHOW

    # an externally set override takes precedence over the requested mode
    if CONVERSION_PLOTMODE_OVERRIDE is not None:
        PLOT_MODE = CONVERSION_PLOTMODE_OVERRIDE
    else:
        PLOT_MODE = plot_mode

    FIG_COUNTER = 0
    CLOSE_AFTER_PLOTSHOW = close_after_plotshow

    if PLOT_MODE == "inline":
        if not IN_JUPYTER:
            raise RuntimeError(
                f"Cannot use inline plotting outside of Jupyter\n{PLOT_MODE = }\t{CONVERSION_PLOTMODE_OVERRIDE = }"
            )
        # `InteractiveShell.magic` is deprecated; `run_line_magic` is the supported API
        get_ipython().run_line_magic("matplotlib", "inline")
        return
    elif PLOT_MODE == "widget":
        if IN_JUPYTER:
            get_ipython().run_line_magic("matplotlib", "widget")
        # outside of Jupyter, matplotlib opens a new window by default
        return
    elif PLOT_MODE == "ignore":
        # disable plotting; accept (and discard) any arguments so calls such
        # as `plt.show(block=False)` keep working
        plt.show = lambda *args, **kwargs: None  # type: ignore[misc]
        return

    # everything except saving handled up to this point
    assert PLOT_MODE == "save", f"Invalid plot mode: {PLOT_MODE}"

    FIG_OUTPUT_FMT = fig_output_fmt
    FIG_NUMBERED_FNAME = fig_numbered_fname
    FIG_CONFIG = fig_config

    # make the chosen format matplotlib's default; `savefig.format` expects a
    # string, so skip it when no format is configured
    if FIG_OUTPUT_FMT is not None:
        plt.rcParams["savefig.format"] = FIG_OUTPUT_FMT

    if FIG_OUTPUT_FMT in TIKZPLOTLIB_FORMATS:
        # check availability early so the user is warned at setup time
        try:
            import tikzplotlib  # type: ignore[import] # noqa: F401
        except ImportError:
            warnings.warn(
                f"Tikzplotlib not installed. Cannot save figures in Tikz format '{FIG_OUTPUT_FMT}', things might break."
            )
    elif FIG_OUTPUT_FMT not in MATPLOTLIB_FORMATS:
        warnings.warn(
            f'Unknown figure format, things might break: {plt.rcParams["savefig.format"] = }'
        )

    # if base path not given, derive one from the config or the current time
    if fig_basepath is None:
        if fig_config is None:
            from datetime import datetime

            fig_basepath = f"figures/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        else:
            from muutils.misc import dict_to_filename

            fig_basepath = f"figures/{dict_to_filename(fig_config)}"

    FIG_BASEPATH = fig_basepath
    os.makedirs(fig_basepath, exist_ok=True)

    # if config given, serialize and save it next to the figures
    if fig_config is not None:
        import json

        from muutils.json_serialize import json_serialize

        with open(f"{fig_basepath}/config.json", "w", encoding="utf-8") as f:
            json.dump(
                json_serialize(fig_config),
                f,
                indent="\t",
            )

    print(f"Figures will be saved to: '{fig_basepath}'")
def configure_notebook(
    *args,
    seed: int = 42,
    device: typing.Any = None,  # this can be a string, torch.device, or None
    dark_mode: bool = True,
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> "torch.device|None":  # type: ignore[name-defined] # noqa: F821
    """Shared Jupyter notebook setup steps

    - Set random seeds and library reproducibility settings
    - Set device based on availability
    - Set module reloading before code execution
    - Set plot formatting
    - Set plot saving/rendering options

    # Parameters:
     - `*args`
        not used; positional arguments are ignored (with a warning), so pass
        every option below as a keyword
     - `seed : int`
        random seed across libraries including torch, numpy, and random
        (defaults to `42`)
     - `device : typing.Any`
        pytorch device to use
        (defaults to `None`)
     - `dark_mode : bool`
        figures in dark mode (matplotlib always; plotly if installed)
        (defaults to `True`)
     - `plot_mode : PlottingMode`
        how to display plots, one of `PlottingMode` or `["ignore", "inline", "widget", "save"]`
        (defaults to `"inline"`)
     - `fig_output_fmt : str | None`
        format for saving figures
        (defaults to `"pdf"`)
     - `fig_numbered_fname : str`
        format for saving figures with numbers (if they aren't named)
        (defaults to `"figure-{num}"`)
     - `fig_config : dict | None`
        metadata to save with the figures
        (defaults to `None`)
     - `fig_basepath : str | None`
        base path for saving figures
        (defaults to `None`)
     - `close_after_plotshow : bool`
        close figures after showing them
        (defaults to `False`)

    # Returns:
     - `torch.device|None`
        the device set, or `None` if `get_device` raises `ImportError`
        (torch not installed)
    """
    if args:
        # silently dropping e.g. `configure_notebook(123)` would hide a mistake
        warnings.warn(
            f"configure_notebook() ignores positional arguments, got {args!r}; pass options as keywords"
        )

    # set the module-level plotting globals
    setup_plots(
        plot_mode=plot_mode,
        fig_output_fmt=fig_output_fmt,
        fig_numbered_fname=fig_numbered_fname,
        fig_config=fig_config,
        fig_basepath=fig_basepath,
        close_after_plotshow=close_after_plotshow,
    )

    # globals are only read here, so no `global` statement is needed
    print(f"set up plots with {PLOT_MODE = }, {FIG_OUTPUT_FMT = }, {FIG_BASEPATH = }")

    # Set seeds and other reproducibility-related library options
    set_reproducibility(seed)

    # Reload modules before executing user code
    if IN_JUPYTER:
        ipython = get_ipython()
        # the extension may be registered under its short or fully-qualified
        # module name, so check both to avoid an "already loaded" message
        if {"autoreload", "IPython.extensions.autoreload"}.isdisjoint(
            ipython.extension_manager.loaded
        ):
            ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

    # Specify plotly renderer for vscode
    if PLOTLY_IMPORTED:
        pio.renderers.default = "notebook_connected"
        if dark_mode:
            pio.templates.default = "plotly_dark"

    # matplotlib styling must not depend on plotly being installed
    if dark_mode:
        plt.style.use("dark_background")

    try:
        return get_device(device)
    except ImportError:
        warnings.warn("Torch not installed. Cannot get/set device.")
        return None
def plotshow(
    fname: str | None = None,
    plot_mode: PlottingMode | None = None,
    fmt: str | None = None,
) -> None:
    """Show the active plot, depending on global configs

    Always increments the module-level figure counter. In `"save"` mode the
    figure is written under `FIG_BASEPATH` via `universal_savefig`. In
    `"inline"`/`"widget"` mode `plt.show()` is called. `"ignore"` does
    nothing. Afterwards the figure is closed if `CLOSE_AFTER_PLOTSHOW` is set.

    # Parameters:
     - `fname : str | None`
        file name (relative to `FIG_BASEPATH`) used in save mode; if `None`,
        `FIG_NUMBERED_FNAME` is formatted with the updated counter
        (defaults to `None`)
     - `plot_mode : PlottingMode | None`
        overrides the global `PLOT_MODE` for this call only
        (defaults to `None`)
     - `fmt : str | None`
        forwarded to `universal_savefig`
        (defaults to `None`)
    """
    global FIG_COUNTER
    FIG_COUNTER += 1

    mode = PLOT_MODE if plot_mode is None else plot_mode

    if mode == "save":
        if fname is None:
            fname = FIG_NUMBERED_FNAME.format(num=FIG_COUNTER)
        assert FIG_BASEPATH is not None, (
            "No figure base path configured: call `setup_plots(plot_mode='save', ...)` "
            "or `configure_notebook(plot_mode='save', ...)` before saving figures"
        )
        universal_savefig(os.path.join(FIG_BASEPATH, fname), fmt=fmt)
    elif mode in ("inline", "widget"):
        plt.show()
    elif mode != "ignore":
        warnings.warn(f"Invalid plot mode: {mode}")

    if CLOSE_AFTER_PLOTSHOW:
        plt.close()