Coverage for tests/unit/nbutils/test_configure_notebook.py: 100%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1import os 

2import warnings 

3 

4import matplotlib.pyplot as plt # type: ignore[import] 

5import pytest 

6import torch 

7 

8from muutils.nbutils.configure_notebook import ( 

9 UnknownFigureFormatWarning, 

10 configure_notebook, 

11 plotshow, 

12 setup_plots, 

13) 

14 

# Scratch directory (relative to the current working directory) into which the
# save-mode tests in this module write their figures.
JUNK_DATA_PATH: str = "tests/_temp/test_cfg_notebook"

16 

17 

@pytest.mark.parametrize(
    "plot_mode",
    (
        # "inline" is left out on purpose: it cannot be used outside a
        # Jupyter notebook.
        "widget",
        "ignore",
    ),
)
def test_setup_plots_donothing(plot_mode):
    """`setup_plots` accepts the non-saving plot modes without raising."""
    setup_plots(plot_mode=plot_mode)

28 

29 

def test_no_inline_outside_nb():
    """Asking for inline plots outside a notebook must raise ``RuntimeError``."""
    expected_error = RuntimeError
    with pytest.raises(expected_error):
        configure_notebook(plot_mode="inline")

33 

34 

def test_setup_plots_save():
    """In "save" mode, `setup_plots` leaves the figure base path in place."""
    setup_plots(fig_basepath=JUNK_DATA_PATH, plot_mode="save")
    basepath_present = os.path.exists(JUNK_DATA_PATH)
    assert basepath_present

38 

39 

def test_configure_notebook():
    """`configure_notebook` returns a ``torch.device``."""
    result = configure_notebook(plot_mode="ignore", seed=42)
    assert isinstance(result, torch.device)

43 

44 

def test_plotshow_save():
    """Unnamed figures are saved under the default numbered filename.

    Each unnamed ``plotshow()`` call is expected to emit an
    ``UnknownFigureFormatWarning`` and to produce ``figure-<n>.pdf`` in the
    scratch directory, numbered consecutively from 1.
    """
    setup_plots(plot_mode="save", fig_basepath=JUNK_DATA_PATH)
    cases = [
        ([1, 2, 3], [1, 2, 3], "figure-1.pdf"),
        ([3, 6, 9], [2, 4, 8], "figure-2.pdf"),
    ]
    for xs, ys, expected_name in cases:
        with pytest.warns(UnknownFigureFormatWarning):
            plt.plot(xs, ys)
            plotshow()
        assert os.path.exists(os.path.join(JUNK_DATA_PATH, expected_name))

55 

56 

def test_plotshow_save_named():
    """Figures given an explicit ``fname`` are saved under exactly that name."""
    setup_plots(plot_mode="save", fig_basepath=JUNK_DATA_PATH)
    for xs, ys, name in (
        ([1, 2, 3], [1, 2, 3], "test.pdf"),
        ([3, 6, 9], [2, 4, 8], "another-test.pdf"),
    ):
        plt.plot(xs, ys)
        plotshow(fname=name)
        assert os.path.exists(os.path.join(JUNK_DATA_PATH, name))

65 

66 

def test_plotshow_save_mixed():
    """Interleave numbered (unnamed) and explicitly-named figure saves.

    The numbered template ``"mixedfig-{num}"`` has no extension, so each
    unnamed save is expected to warn and land as ``.pdf``. The expected name
    ``mixedfig-3`` for the third figure implies that the explicitly-named
    figure in between also advances the figure counter.
    """
    setup_plots(
        plot_mode="save",
        fig_basepath=JUNK_DATA_PATH,
        fig_numbered_fname="mixedfig-{num}",
    )

    def saved(name):
        # True when a file with this name exists in the scratch directory
        return os.path.exists(os.path.join(JUNK_DATA_PATH, name))

    with pytest.warns(UnknownFigureFormatWarning):
        plt.plot([1, 2, 3], [1, 2, 3])
        plotshow()
    assert saved("mixedfig-1.pdf")

    plt.plot([3, 6, 9], [2, 4, 8])
    plotshow(fname="mixed-test.pdf")
    assert saved("mixed-test.pdf")

    with pytest.warns(UnknownFigureFormatWarning):
        plt.plot([1, 1, 1], [1, 9, 9])
        plotshow()
    assert saved("mixedfig-3.pdf")

84 

85 

def test_warn_unknown_format():
    """Saving with an extension-less numbered template emits the warning."""
    numbered_template = "mixedfig-{num}"  # deliberately has no file extension
    with pytest.warns(UnknownFigureFormatWarning):
        setup_plots(
            fig_numbered_fname=numbered_template,
            fig_basepath=JUNK_DATA_PATH,
            plot_mode="save",
        )
        plt.plot([1, 2, 3], [1, 2, 3])
        plotshow()

95 

96 

def test_no_warn_unknown_format_2():
    """Saving under the extension-less name ``"no-format"`` emits the warning.

    NOTE(review): despite the ``no_warn`` in its name, this test *expects*
    ``UnknownFigureFormatWarning``. The numbered template here also lacks an
    extension, so the warning may come from either call. The name is kept
    unchanged so the test id stays stable.
    """
    numbered_template = "mixedfig-{num}"
    with pytest.warns(UnknownFigureFormatWarning):
        setup_plots(
            fig_numbered_fname=numbered_template,
            fig_basepath=JUNK_DATA_PATH,
            plot_mode="save",
        )
        plt.plot([1, 2, 3], [1, 2, 3])
        plotshow("no-format")

106 

107 

def test_no_warn_pdf_format():
    """A numbered template that ends in ``.pdf`` must not emit any warning.

    All warnings are escalated to errors for the duration of the test. The
    test therefore fails if setting up or saving the figure raises
    ``UnknownFigureFormatWarning`` (or any other warning).
    """
    with warnings.catch_warnings():
        # turn every warning into an exception so that any warning fails
        warnings.simplefilter("error")
        setup_plots(
            plot_mode="save",
            # Use the shared scratch directory. Passing the literal string
            # "JUNK_DATA_PATH" would create a stray directory of that name
            # in the working directory.
            fig_basepath=JUNK_DATA_PATH,
            fig_numbered_fname="fig-{num}.pdf",
        )
        plt.plot([1, 2, 3], [1, 2, 3])
        plotshow()

118 

119 

def test_plotshow_ignore():
    """In "ignore" mode, ``plotshow`` is expected to be a no-op."""
    setup_plots(plot_mode="ignore")
    plt.plot([1, 2, 3], [1, 2, 3])
    plotshow()  # expected to neither save nor display anything