Coverage for muutils/nbutils/convert_ipynb_to_script.py: 67%

125 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""fast conversion of Jupyter Notebooks to scripts, with some basic and hacky filtering and formatting.""" 

2 

3from __future__ import annotations 

4 

5import argparse 

6import json 

7import os 

8from pathlib import Path 

9import sys 

10import typing 

11import warnings 

12 

13from muutils.spinner import SpinnerContext 

14 

# Snippets injected into converted scripts when plots are disabled.
# Each key is a library name that `disable_plots_in_script` searches for in the
# script text; each value is a list of chunks inserted into the script.
DISABLE_PLOTS: dict[str, list[str]] = dict(
    # replaces `plt.show` with a no-op right after matplotlib is imported
    matplotlib=[
        """
# ------------------------------------------------------------
# Disable matplotlib plots, done during processing by `convert_ipynb_to_script.py`
import matplotlib.pyplot as plt
plt.show = lambda: None
# ------------------------------------------------------------
"""
    ],
    # defines a `render` that returns the visualization HTML as a plain string.
    # NOTE(review): `render = new_render` only rebinds the name `render` in the
    # generated script's own namespace; circuitsvis functions that imported
    # `render` themselves presumably keep using the original -- confirm this works.
    circuitsvis=[
        """
# ------------------------------------------------------------
# Disable circuitsvis plots, done during processing by `convert_ipynb_to_script.py`
from circuitsvis.utils.convert_props import PythonProperty, convert_props
from circuitsvis.utils.render import RenderedHTML, render, render_cdn, render_local

def new_render(
    react_element_name: str,
    **kwargs: PythonProperty
) -> RenderedHTML:
    "return a visualization as raw HTML"
    local_src = render_local(react_element_name, **kwargs)
    cdn_src = render_cdn(react_element_name, **kwargs)
    # return as string instead of RenderedHTML for CI
    return str(RenderedHTML(local_src, cdn_src))

render = new_render
# ------------------------------------------------------------
"""
    ],
    # sets muutils' conversion plot-mode override to "ignore"
    # (presumably makes `configure_notebook` skip plot output -- verify there)
    muutils=[
        """import muutils.nbutils.configure_notebook as nb_conf
nb_conf.CONVERSION_PLOTMODE_OVERRIDE = "ignore"
"""
    ],
)

52 

# Banner prepended (as the first chunk) to scripts converted with plots disabled.
# Built from adjacent string literals so no leading newline needs stripping.
DISABLE_PLOTS_WARNING: list[str] = [
    "# ------------------------------------------------------------\n"
    "# WARNING: this script is auto-generated by `convert_ipynb_to_script.py`\n"
    "# showing plots has been disabled, so this is presumably in a temp dict for CI or something\n"
    "# so don't modify this code, it will be overwritten!\n"
    "# ------------------------------------------------------------\n"
]

62 

63 

def _looks_like_import(chunk: str) -> bool:
    """Crude check: does this chunk contain the substring "import" or "from"?"""
    return ("import" in chunk) or ("from" in chunk)


def _insert_after_last_import(
    script_lines: list[str],
    lib_name: str,
    disable_lines: list[str],
    is_import: typing.Callable[[str], bool],
    is_usage: typing.Callable[[str], bool],
) -> list[str]:
    """Return a new list with `disable_lines` inserted right after the last chunk matching `is_import`.

    # Raises
    - `AssertionError`: if no chunk matches `is_import`, or if the last chunk
      matching `is_usage` is not after the last import chunk.

    Emits a `UserWarning` if no chunk matches `is_usage`.
    """
    last_import_index: int = -1
    last_usage_index: int = -1
    for i, chunk in enumerate(script_lines):
        if is_import(chunk):
            last_import_index = i
        # NOTE: this keeps the *last* usage, so the ordering check below only
        # guarantees that at least one usage comes after the insertion point
        if is_usage(chunk):
            last_usage_index = i

    # explicit raises (not `assert`) so the checks are not stripped under `python -O`;
    # AssertionError is kept because callers of `convert_ipynb` catch that type
    if last_import_index == -1:
        raise AssertionError(f"{lib_name} imports not found!")
    if last_usage_index != -1:
        if last_usage_index <= last_import_index:
            raise AssertionError(
                f"{lib_name} plots created before import! see lines {last_usage_index}, {last_import_index}"
            )
    else:
        warnings.warn(
            f"could not find where {lib_name} is used, plot disabling might not work!"
        )

    return (
        script_lines[: last_import_index + 1]
        + disable_lines
        + script_lines[last_import_index + 1 :]
    )


def disable_plots_in_script(script_lines: list[str]) -> list[str]:
    """Disable plots in a script by adding cursed things after the import statements.

    `script_lines` is a list of script chunks (`convert_ipynb` passes one chunk per
    cell). Matching is plain substring search on each chunk, so it is hacky:
    - if "muutils" appears anywhere, `DISABLE_PLOTS["muutils"]` is prepended;
    - if "matplotlib" appears, `DISABLE_PLOTS["matplotlib"]` is inserted after the last
      chunk that mentions matplotlib alongside "import"/"from", or mentions `configure_notebook`;
    - if "circuitsvis" appears, `DISABLE_PLOTS["circuitsvis"]` is inserted after the last
      chunk that mentions circuitsvis alongside "import"/"from".

    The input list is not mutated; if nothing matches, it is returned as-is.

    # Raises
    - `AssertionError`: if matplotlib is mentioned but `import matplotlib.pyplot as plt`
      is absent, if a mentioned library has no import chunk, or if its last usage
      does not come after its last import.
    """
    lines: list[str] = script_lines

    if "muutils" in "\n\n".join(lines):
        lines = DISABLE_PLOTS["muutils"] + lines

    text: str = "\n\n".join(lines)
    if "matplotlib" in text:
        if "import matplotlib.pyplot as plt" not in text:
            raise AssertionError("matplotlib.pyplot must be imported as plt")
        lines = _insert_after_last_import(
            lines,
            "matplotlib",
            DISABLE_PLOTS["matplotlib"],
            # `configure_notebook` chunks count as import points, so the
            # override lands after notebook configuration
            is_import=lambda c: (
                ("matplotlib" in c and _looks_like_import(c))
                or "configure_notebook" in c
            ),
            is_usage=lambda c: "plt." in c,
        )

    if "circuitsvis" in "\n\n".join(lines):
        lines = _insert_after_last_import(
            lines,
            "circuitsvis",
            DISABLE_PLOTS["circuitsvis"],
            is_import=lambda c: "circuitsvis" in c and _looks_like_import(c),
            is_usage=lambda c: (
                ("circuitsvis" in c and not _looks_like_import(c)) or "render" in c
            ),
        )

    return lines

149 

150 

def convert_ipynb(
    notebook: dict,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = (
        "%",
        "!",
    ),  # ignore notebook magic commands and shell commands
) -> str:
    """Convert Jupyter Notebook to a script, doing some basic filtering and formatting.

    # Arguments
    - `notebook: dict`: Jupyter Notebook loaded as json. A cell's `source` may be a
      list of lines or a single string (the nbformat schema allows both).
    - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
      Otherwise each markdown cell is emitted as a bare triple-quoted string; any run
      of three double quotes inside it is backslash-escaped so the script stays valid.
    - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
    - `disable_plots: bool = False`: Disable plots in the output script.
    - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
        if a string is passed, it will be split by char and each char will be treated as a separate filter.

    Cells of any other type (e.g. "raw") are dropped. Output chunks are joined with a blank line.

    # Returns
    - `str`: Converted script.

    # Raises
    - `AssertionError`: propagated from `disable_plots_in_script` when `disable_plots` is set.
    """
    # a plain string is split into single-character prefixes, e.g. "%!" -> ("%", "!")
    if isinstance(filter_out_lines, str):
        filter_out_lines = tuple(filter_out_lines)
    # `str.startswith` accepts a tuple and checks every prefix in one call;
    # an empty tuple matches nothing, so no lines are filtered
    filter_prefixes: tuple[str, ...] = tuple(filter_out_lines)

    result: list[str] = []

    for cell in notebook["cells"]:
        cell_type: str = cell["cell_type"]
        source: str | list[str] = cell["source"]
        # a string source must be split into lines; iterating it directly would
        # filter per character (e.g. the "!" in `x != y` would be commented out)
        if isinstance(source, str):
            source = source.splitlines(keepends=True)

        if cell_type == "markdown":
            if strip_md_cells:
                continue
            # escape embedded triple quotes so they don't terminate the string early
            md_text: str = "".join(source).replace('"""', '\\"\\"\\"')
            result.append(f'{header_comment}\n"""\n{md_text}\n"""')
        elif cell_type == "code":
            code_lines: list[str] = [
                f"#{line}" if line.startswith(filter_prefixes) else line
                for line in source
            ]
            result.append(f'{header_comment}\n{"".join(code_lines)}')

    if disable_plots:
        result = disable_plots_in_script(result)
        result = DISABLE_PLOTS_WARNING + result

    return "\n\n".join(result)

209 

210 

def process_file(
    in_file: str,
    out_file: str | None = None,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    """Convert a single Jupyter Notebook file to a script.

    # Arguments
    - `in_file: str`: path to the input `.ipynb` file.
    - `out_file: str | None = None`: path to write the script to (overwritten if it
      exists). If falsy, the script is printed to stdout.
    - `strip_md_cells`, `header_comment`, `disable_plots`, `filter_out_lines`:
      passed through unchanged to `convert_ipynb`.

    Progress and error messages go to stderr, so stdout holds only the script.

    # Raises
    - `AssertionError`: if `in_file` does not exist, is not a file, or does not end in
      `.ipynb`; or if `convert_ipynb` raises one (re-raised after logging to stderr).
    """
    print(f"\tProcessing {in_file}...", file=sys.stderr)
    assert os.path.exists(in_file), f"File {in_file} does not exist."
    assert os.path.isfile(in_file), f"Path {in_file} is not a file."
    assert in_file.endswith(".ipynb"), f"File {in_file} is not a Jupyter Notebook."

    # notebooks are JSON (UTF-8); don't depend on the platform's default encoding
    with open(in_file, "r", encoding="utf-8") as file:
        notebook: dict = json.load(file)

    try:
        converted_script: str = convert_ipynb(
            notebook=notebook,
            strip_md_cells=strip_md_cells,
            header_comment=header_comment,
            disable_plots=disable_plots,
            filter_out_lines=filter_out_lines,
        )
    except AssertionError as e:
        print(f"Error converting {in_file}: {e}", file=sys.stderr)
        raise

    if out_file:
        with open(out_file, "w", encoding="utf-8") as file:
            file.write(converted_script)
    else:
        print(converted_script)

244 

245 

def process_dir(
    input_dir: typing.Union[str, Path],
    output_dir: typing.Union[str, Path],
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    """Convert all Jupyter Notebooks in a directory to scripts.

    Only the top level of `input_dir` is scanned. Notebooks are processed in sorted
    filename order, and `name.ipynb` is written to `output_dir/name.py`, overwriting
    any existing file. `output_dir` is created if needed.

    # Arguments
    - `input_dir: str`: Input directory.
    - `output_dir: str`: Output directory.
    - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
    - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
    - `disable_plots: bool = False`: Disable plots in the output script.
    - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
        if a string is passed, it will be split by char and each char will be treated as a separate filter.

    # Raises
    - `AssertionError`: if `input_dir` does not exist, is not a directory, or contains no `.ipynb` files.
    - `Exception`: if `convert_ipynb` raises an `AssertionError` (chained as `__cause__`).
    """
    assert os.path.exists(input_dir), f"Directory {input_dir} does not exist."
    assert os.path.isdir(input_dir), f"Path {input_dir} is not a directory."

    # exist_ok already tolerates an existing directory, so no separate check is needed
    os.makedirs(output_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort for deterministic runs
    filenames: list[str] = sorted(
        fname for fname in os.listdir(input_dir) if fname.endswith(".ipynb")
    )

    assert filenames, f"Directory {input_dir} does not contain any Jupyter Notebooks."
    n_files: int = len(filenames)
    print(f"Converting {n_files} notebooks:", file=sys.stderr)

    with SpinnerContext(
        spinner_chars="braille",
        update_interval=0.01,
        format_string_when_updated=True,
        output_stream=sys.stderr,
    ) as spinner:
        for idx, fname in enumerate(filenames):
            spinner.update_value(f"\tConverting {idx+1}/{n_files}: {fname}")
            in_file: str = os.path.join(input_dir, fname)
            # swap only the trailing extension; `str.replace` would also rewrite an
            # ".ipynb" earlier in the name (e.g. "a.ipynb.ipynb" -> "a.py.py")
            out_file: str = os.path.join(
                output_dir, os.path.splitext(fname)[0] + ".py"
            )

            with open(in_file, "r", encoding="utf-8") as file_in:
                notebook: dict = json.load(file_in)

            try:
                converted_script: str = convert_ipynb(
                    notebook=notebook,
                    strip_md_cells=strip_md_cells,
                    header_comment=header_comment,
                    disable_plots=disable_plots,
                    filter_out_lines=filter_out_lines,
                )
            except AssertionError as e:
                # stop the spinner first so the traceback isn't interleaved with it
                spinner.stop()
                raise Exception(f"Error converting {in_file}") from e

            with open(out_file, "w", encoding="utf-8") as file_out:
                file_out.write(converted_script)

308 

309 

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert Jupyter Notebook to a script with cell separators."
    )
    parser.add_argument(
        "in_path",
        type=str,
        help="Input Jupyter Notebook file (.ipynb) or directory of files.",
    )
    parser.add_argument(
        "--out-file",
        type=str,
        help="Output script file. If not specified, the result will be printed to stdout.",
    )
    parser.add_argument(
        "--output-dir", type=str, help="Output directory for converted script files."
    )
    parser.add_argument(
        "--strip-md-cells",
        action="store_true",
        help="Remove markdown cells from the output script.",
    )
    parser.add_argument(
        "--header-comment",
        type=str,
        default=r"#%%",
        help="Comment string to separate cells in the output script.",
    )
    parser.add_argument(
        "--disable-plots",
        action="store_true",
        help="Disable plots in the output script. Useful for testing in CI.",
    )
    parser.add_argument(
        "--filter-out-lines",
        type=str,
        # match the library default ("%", "!"): without "!", shell commands such as
        # `!pip install ...` stay in the script, and they are not valid Python
        default="%!",
        # argparse %-formats help strings; `%(default)s` inserts the default value
        help="Comment out lines starting with these characters (default: %(default)s).",
    )

    args = parser.parse_args()

    if args.output_dir:
        if args.out_file:
            # parser.error prints usage and exits; unlike assert it is not stripped by -O
            parser.error("Cannot specify both --out-file and --output-dir.")
        process_dir(
            input_dir=args.in_path,
            output_dir=args.output_dir,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )
    else:
        process_file(
            in_file=args.in_path,
            out_file=args.out_file,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )

372 

373 

# report the load on stderr: when no --out-file is given, the `__main__` block above
# has already printed the converted script to stdout, and printing there would
# append a line of invalid Python to it
print("muutils.nbutils.convert_ipynb_to_script.py loaded.", file=sys.stderr)