Coverage for muutils/nbutils/configure_notebook.py: 40%

133 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""shared utilities for setting up a notebook""" 

2 

3from __future__ import annotations 

4 

5import os 

6import typing 

7import warnings 

8 

9import matplotlib.pyplot as plt # type: ignore[import] 

10 

11 

class PlotlyNotInstalledWarning(UserWarning):
    """Warning emitted at import time when `plotly` cannot be imported."""


# handle plotly importing -- plotly is optional: if the import fails we only
# warn and record the outcome, so plotly-specific setup can be skipped later
PLOTLY_IMPORTED: bool
try:
    import plotly.io as pio  # type: ignore[import]

    PLOTLY_IMPORTED = True
except ImportError:
    warnings.warn(
        "Plotly not installed. Plotly plots will not be available.",
        PlotlyNotInstalledWarning,
    )
    PLOTLY_IMPORTED = False

# figure out if we're in a jupyter notebook: `get_ipython()` returns `None`
# when no IPython shell is active
# NOTE(review): this is also True in a terminal IPython session, not only in Jupyter
IN_JUPYTER: bool
try:
    from IPython import get_ipython  # type: ignore[import-not-found]

    IN_JUPYTER = get_ipython() is not None
except ImportError:
    # IPython is not installed, so we cannot be running inside a notebook
    IN_JUPYTER = False

36 

37# muutils imports 

38from muutils.mlutils import get_device, set_reproducibility # noqa: E402 

39 

# handling figures
# ================

# how `plotshow` treats a figure: discard it, show it inline, show it as an
# interactive widget, or save it to disk
PlottingMode = typing.Literal["ignore", "inline", "widget", "save"]

# active plotting mode; assigned by `setup_plots`
PLOT_MODE: PlottingMode = "inline"
# when not `None`, `setup_plots` uses this instead of its `plot_mode` argument.
# nothing in this module assigns it -- presumably set externally (e.g. when
# converting notebooks); TODO confirm
CONVERSION_PLOTMODE_OVERRIDE: PlottingMode | None = None
# incremented on every `plotshow` call and reset to 0 by `setup_plots`;
# used to number figures that are saved without an explicit name
FIG_COUNTER: int = 0
# fallback output format used by `universal_savefig` for unrecognized formats;
# assigned by `setup_plots` in "save" mode
FIG_OUTPUT_FMT: str | None = None
# file-name template for unnamed figures, formatted with `num=FIG_COUNTER`
FIG_NUMBERED_FNAME: str = "figure-{num}"
# metadata saved alongside the figures (written to `config.json` by `setup_plots`)
FIG_CONFIG: dict | None = None
# directory that saved figures are written into
FIG_BASEPATH: str | None = None
# whether `plotshow` closes the current figure after handling it
CLOSE_AFTER_PLOTSHOW: bool = False

# formats written with `plt.savefig`
MATPLOTLIB_FORMATS = [
    "pdf",
    "png",
    "jpg",
    "jpeg",
    "svg",
    "eps",
    "ps",
    "tif",
    "tiff",
]
# formats written with `tikzplotlib.save`
TIKZPLOTLIB_FORMATS = ["tex", "tikz"]

53 

54 

class UnknownFigureFormatWarning(UserWarning):
    """Warning emitted by `universal_savefig` when it does not recognize the requested figure format."""

57 

58 

def universal_savefig(fname: str, fmt: str | None = None) -> None:
    """Save the current matplotlib figure, choosing the writer from the format.

    Formats in `MATPLOTLIB_FORMATS` are written with `plt.savefig`, formats in
    `TIKZPLOTLIB_FORMATS` with `tikzplotlib.save`. Anything else (including
    `None`, when no default format is configured) is saved with matplotlib's
    default format after a warning.

    # Parameters:
     - `fname : str`
        output path; `.{fmt}` is appended if the path does not already end with it
     - `fmt : str | None`
        output format (case-insensitive). If `None`, it is inferred from the
        extension of `fname`; if `fname` has no extension, `FIG_OUTPUT_FMT` is
        used. Unrecognized formats emit `UnknownFigureFormatWarning` and fall
        back to `FIG_OUTPUT_FMT`.
        (defaults to `None`)
    """
    if fmt is None:
        # infer from the file name's extension only: `splitext` ignores dots in
        # directory names (e.g. a basepath derived from a config dict)
        ext: str = os.path.splitext(fname)[1]
        # `ext` is "" (no extension) or "." (trailing dot) -> nothing to infer
        fmt = ext[1:] or None
    if fmt is not None:
        fmt = fmt.lower()

    if fmt is None:
        # no explicit format and no extension: silently use the configured default
        fmt = FIG_OUTPUT_FMT
    elif fmt not in MATPLOTLIB_FORMATS and fmt not in TIKZPLOTLIB_FORMATS:
        warnings.warn(
            f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'",
            UnknownFigureFormatWarning,
        )
        fmt = FIG_OUTPUT_FMT

    # append the extension unless already present; compare against ".{fmt}" so
    # that e.g. "result-ps" is not mistaken for a file with a ".ps" extension.
    # `fmt` can still be `None` here if no default format is configured.
    if fmt is not None and not fname.lower().endswith(f".{fmt.lower()}"):
        fname += f".{fmt}"

    if fmt in MATPLOTLIB_FORMATS:
        plt.savefig(fname, format=fmt, bbox_inches="tight")
    elif fmt in TIKZPLOTLIB_FORMATS:
        # optional dependency, only imported when actually needed
        import tikzplotlib  # type: ignore[import]

        tikzplotlib.save(fname)
    else:
        warnings.warn(f"Unknown format '{fmt}', going with matplotlib default")
        plt.savefig(fname, bbox_inches="tight")

84 

85 

def setup_plots(
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> None:
    """Set up plot saving/rendering options

    Resets the figure counter and updates the module-level plotting globals.
    The `fig_*` options (other than `close_after_plotshow`) only take effect in
    `"save"` mode, where the output directory is also created.

    # Parameters:
     - `plot_mode : PlottingMode`
        how to handle plots; ignored when `CONVERSION_PLOTMODE_OVERRIDE` is set
        (defaults to `"inline"`)
     - `fig_output_fmt : str | None`
        default format for saved figures; also set as matplotlib's
        `savefig.format` when not `None`
        (defaults to `"pdf"`)
     - `fig_numbered_fname : str`
        file-name template for unnamed figures, formatted with `num=<counter>`
        (defaults to `"figure-{num}"`)
     - `fig_config : dict | None`
        metadata saved as `config.json` next to the figures; also used to name
        the output directory when `fig_basepath` is not given
        (defaults to `None`)
     - `fig_basepath : str | None`
        directory to save figures in; if `None`, a directory under `figures/`
        named after `fig_config` (or the current time) is used
        (defaults to `None`)
     - `close_after_plotshow : bool`
        close the current figure after each `plotshow` call
        (defaults to `False`)

    # Raises:
     - `RuntimeError` : if `"inline"` mode is requested outside of IPython/Jupyter
     - `AssertionError` : if the resulting plot mode is not a valid `PlottingMode`
    """
    # CONVERSION_PLOTMODE_OVERRIDE is only read here, so it needs no `global`
    global \
        PLOT_MODE, \
        FIG_COUNTER, \
        FIG_OUTPUT_FMT, \
        FIG_NUMBERED_FNAME, \
        FIG_CONFIG, \
        FIG_BASEPATH, \
        CLOSE_AFTER_PLOTSHOW

    # set plot mode, an override (if set) takes precedence over the argument
    if CONVERSION_PLOTMODE_OVERRIDE is not None:
        PLOT_MODE = CONVERSION_PLOTMODE_OVERRIDE
    else:
        PLOT_MODE = plot_mode

    FIG_COUNTER = 0
    CLOSE_AFTER_PLOTSHOW = close_after_plotshow

    if PLOT_MODE == "inline":
        if not IN_JUPYTER:
            raise RuntimeError(
                f"Cannot use inline plotting outside of Jupyter\n{PLOT_MODE = }\t{CONVERSION_PLOTMODE_OVERRIDE = }"
            )
        # `run_line_magic` replaces the deprecated `InteractiveShell.magic`
        get_ipython().run_line_magic("matplotlib", "inline")
        return
    if PLOT_MODE == "widget":
        # outside of jupyter, matplotlib brings up a new window by default,
        # so there is nothing to configure
        if IN_JUPYTER:
            get_ipython().run_line_magic("matplotlib", "widget")
        return
    if PLOT_MODE == "ignore":
        # disable plotting; accept and ignore any arguments, since callers may
        # pass e.g. `block=...` to `plt.show`
        # NOTE(review): this patch is never undone, even if `setup_plots` is
        # later called with another mode
        plt.show = lambda *args, **kwargs: None  # type: ignore[misc]
        return

    # everything except saving handled up to this point
    assert PLOT_MODE == "save", f"Invalid plot mode: {PLOT_MODE}"

    FIG_OUTPUT_FMT = fig_output_fmt
    FIG_NUMBERED_FNAME = fig_numbered_fname
    FIG_CONFIG = fig_config

    # set default figure format in rcParams savefig.format -- the rcParam
    # expects a string, so `None` leaves matplotlib's default in place
    if FIG_OUTPUT_FMT is not None:
        plt.rcParams["savefig.format"] = FIG_OUTPUT_FMT

    if FIG_OUTPUT_FMT in TIKZPLOTLIB_FORMATS:
        # check the optional dependency early so the user is warned up front
        try:
            import tikzplotlib  # type: ignore[import] # noqa: F401
        except ImportError:
            warnings.warn(
                f"Tikzplotlib not installed. Cannot save figures in Tikz format '{FIG_OUTPUT_FMT}', things might break."
            )
    elif FIG_OUTPUT_FMT is not None and FIG_OUTPUT_FMT not in MATPLOTLIB_FORMATS:
        warnings.warn(
            f'Unknown figure format, things might break: {plt.rcParams["savefig.format"] = }'
        )

    # if base path not given, make one
    if fig_basepath is None:
        if fig_config is None:
            # if no config, use the current time
            from datetime import datetime

            fig_basepath = f"figures/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        else:
            # if config given, convert to string
            from muutils.misc import dict_to_filename

            fig_basepath = f"figures/{dict_to_filename(fig_config)}"

    FIG_BASEPATH = fig_basepath
    os.makedirs(fig_basepath, exist_ok=True)

    # if config given, serialize and save that config
    if fig_config is not None:
        import json

        from muutils.json_serialize import json_serialize

        config_path: str = os.path.join(fig_basepath, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(
                json_serialize(fig_config),
                f,
                indent="\t",
            )

    print(f"Figures will be saved to: '{fig_basepath}'")

190 

191 

def configure_notebook(
    *args,
    seed: int = 42,
    device: typing.Any = None,  # this can be a string, torch.device, or None
    dark_mode: bool = True,
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> "torch.device|None":  # type: ignore[name-defined] # noqa: F821
    """Shared Jupyter notebook setup steps

    - Set random seeds and library reproducibility settings
    - Set device based on availability
    - Set module reloading before code execution
    - Set plot formatting
    - Set plot saving/rendering options

    All options are keyword-only; positional arguments are ignored (with a warning).

    # Parameters:
     - `seed : int`
        random seed across libraries including torch, numpy, and random
        (defaults to `42`)
     - `device : typing.Any`
        pytorch device to use
        (defaults to `None`)
     - `dark_mode : bool`
        figures in dark mode
        (defaults to `True`)
     - `plot_mode : PlottingMode`
        how to display plots, one of `PlottingMode` or `["ignore", "inline", "widget", "save"]`
        (defaults to `"inline"`)
     - `fig_output_fmt : str | None`
        format for saving figures
        (defaults to `"pdf"`)
     - `fig_numbered_fname : str`
        format for saving figures with numbers (if they aren't named)
        (defaults to `"figure-{num}"`)
     - `fig_config : dict | None`
        metadata to save with the figures
        (defaults to `None`)
     - `fig_basepath : str | None`
        base path for saving figures
        (defaults to `None`)
     - `close_after_plotshow : bool`
        close figures after showing them
        (defaults to `False`)

    # Returns:
     - `torch.device|None`
        the device set, if torch is installed
    """
    # positional arguments would otherwise be dropped silently
    if args:
        warnings.warn(
            f"configure_notebook only accepts keyword arguments, ignoring positional arguments: {args!r}"
        )

    # set some globals related to plotting
    setup_plots(
        plot_mode=plot_mode,
        fig_output_fmt=fig_output_fmt,
        fig_numbered_fname=fig_numbered_fname,
        fig_config=fig_config,
        fig_basepath=fig_basepath,
        close_after_plotshow=close_after_plotshow,
    )

    # the plotting globals are only read here, so no `global` statement is needed
    print(f"set up plots with {PLOT_MODE = }, {FIG_OUTPUT_FMT = }, {FIG_BASEPATH = }")

    # Set seeds and other reproducibility-related library options
    set_reproducibility(seed)

    # Reload modules before executing user code
    if IN_JUPYTER:
        ipython = get_ipython()
        # NOTE(review): depending on the IPython version the extension is
        # presumably registered as "autoreload" or "IPython.extensions.autoreload";
        # check both so it is not loaded twice
        loaded_exts = ipython.extension_manager.loaded
        if (
            "autoreload" not in loaded_exts
            and "IPython.extensions.autoreload" not in loaded_exts
        ):
            # `run_line_magic` replaces the deprecated `InteractiveShell.magic`
            ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

    # Specify plotly renderer for vscode
    if PLOTLY_IMPORTED:
        pio.renderers.default = "notebook_connected"

        if dark_mode:
            pio.templates.default = "plotly_dark"
            plt.style.use("dark_background")

    try:
        # Set device
        device = get_device(device)
        return device
    except ImportError:
        warnings.warn("Torch not installed. Cannot get/set device.")
        return None

285 

286 

def plotshow(
    fname: str | None = None,
    plot_mode: PlottingMode | None = None,
    fmt: str | None = None,
) -> None:
    """Show the active plot, depending on global configs

    Increments the global figure counter on every call, then saves, shows, or
    discards the current figure according to the plot mode. If
    `CLOSE_AFTER_PLOTSHOW` is set, the current figure is closed afterwards.

    # Parameters:
     - `fname : str | None`
        file name (relative to `FIG_BASEPATH`) when saving; if `None`, a name
        is built from `FIG_NUMBERED_FNAME` and the figure counter
        (defaults to `None`)
     - `plot_mode : PlottingMode | None`
        per-call override of the global `PLOT_MODE`
        (defaults to `None`)
     - `fmt : str | None`
        format passed to `universal_savefig` when saving
        (defaults to `None`)

    # Raises:
     - `AssertionError` : when saving but no output directory has been configured
    """
    # only the counter is reassigned; the other globals are just read
    global FIG_COUNTER
    FIG_COUNTER += 1

    if plot_mode is None:
        plot_mode = PLOT_MODE

    if plot_mode == "save":
        # get numbered figure name if not given
        if fname is None:
            fname = FIG_NUMBERED_FNAME.format(num=FIG_COUNTER)

        # FIG_BASEPATH is only set by `setup_plots` in "save" mode, so a
        # per-call "save" override needs that setup to have happened first
        assert FIG_BASEPATH is not None, (
            "no figure output directory configured: call `setup_plots` "
            "(or `configure_notebook`) with plot_mode='save' before saving figures"
        )
        universal_savefig(os.path.join(FIG_BASEPATH, fname), fmt=fmt)
    elif plot_mode in ("inline", "widget"):
        # both interactive modes just display the figure
        plt.show()
    elif plot_mode != "ignore":
        warnings.warn(f"Invalid plot mode: {plot_mode}")

    if CLOSE_AFTER_PLOTSHOW:
        plt.close()