Coverage for muutils/nbutils/convert_ipynb_to_script.py: 67%
125 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""fast conversion of Jupyter Notebooks to scripts, with some basic and hacky filtering and formatting."""
3from __future__ import annotations
5import argparse
6import json
7import os
8from pathlib import Path
9import sys
10import typing
11import warnings
13from muutils.spinner import SpinnerContext
# Code snippets that `disable_plots_in_script` splices into a converted script.
# Every value of `DISABLE_PLOTS` is a one-element list so that it can be
# concatenated directly with the list of script entries.

# replaces `plt.show` with a no-op
_MATPLOTLIB_SNIPPET: str = """
# ------------------------------------------------------------
# Disable matplotlib plots, done during processing by `convert_ipynb_to_script.py`
import matplotlib.pyplot as plt
plt.show = lambda: None
# ------------------------------------------------------------
"""

# rebinds the name `render` to a function returning the HTML as a plain string
_CIRCUITSVIS_SNIPPET: str = """
# ------------------------------------------------------------
# Disable circuitsvis plots, done during processing by `convert_ipynb_to_script.py`
from circuitsvis.utils.convert_props import PythonProperty, convert_props
from circuitsvis.utils.render import RenderedHTML, render, render_cdn, render_local

def new_render(
    react_element_name: str,
    **kwargs: PythonProperty
) -> RenderedHTML:
    "return a visualization as raw HTML"
    local_src = render_local(react_element_name, **kwargs)
    cdn_src = render_cdn(react_element_name, **kwargs)
    # return as string instead of RenderedHTML for CI
    return str(RenderedHTML(local_src, cdn_src))

render = new_render
# ------------------------------------------------------------
"""

# sets the plot-mode override of `muutils.nbutils.configure_notebook` to "ignore"
_MUUTILS_SNIPPET: str = """import muutils.nbutils.configure_notebook as nb_conf
nb_conf.CONVERSION_PLOTMODE_OVERRIDE = "ignore"
"""

DISABLE_PLOTS: dict[str, list[str]] = {
    "matplotlib": [_MATPLOTLIB_SNIPPET],
    "circuitsvis": [_CIRCUITSVIS_SNIPPET],
    "muutils": [_MUUTILS_SNIPPET],
}
# Banner that `convert_ipynb` puts at the very top of a script converted with
# `disable_plots=True`. It is a one-element list so it can be concatenated with
# the list of script entries; the text has no leading newline.
DISABLE_PLOTS_WARNING: list[str] = [
    "# ------------------------------------------------------------\n"
    "# WARNING: this script is auto-generated by `convert_ipynb_to_script.py`\n"
    "# showing plots has been disabled, so this is presumably in a temp dict for CI or something\n"
    "# so don't modify this code, it will be overwritten!\n"
    "# ------------------------------------------------------------\n"
]
def disable_plots_in_script(script_lines: list[str]) -> list[str]:
    """Disable plots in a script by adding cursed things after the import statements

    Despite the name, each entry of `script_lines` is scanned as a whole
    (`convert_ipynb` passes one entry per converted notebook cell). The
    snippets from `DISABLE_PLOTS` are spliced in as additional entries:

    - `"muutils"` appears anywhere: the muutils snippet is prepended.
    - `"matplotlib"` appears anywhere: the matplotlib snippet is inserted right
      after the last entry that imports matplotlib or mentions `configure_notebook`.
    - `"circuitsvis"` appears anywhere: the circuitsvis snippet is inserted right
      after the last entry that imports circuitsvis or mentions `configure_notebook`.

    # Arguments
    - `script_lines: list[str]`: script entries; the list itself is not modified.

    # Returns
    - `list[str]`: the entries with the plot-disabling snippets inserted.

    # Raises
    - `AssertionError`: if matplotlib is mentioned but never imported as
      `import matplotlib.pyplot as plt`, if no import entry is found for a
      detected library, or if the tracked usage entry does not come after the
      insertion point.
    """
    result_str_TEMP: str = "\n\n".join(script_lines)
    script_lines_new: list[str] = script_lines

    if "muutils" in result_str_TEMP:
        script_lines_new = DISABLE_PLOTS["muutils"] + script_lines_new

    if "matplotlib" in result_str_TEMP:
        assert (
            "import matplotlib.pyplot as plt" in result_str_TEMP
        ), "matplotlib.pyplot must be imported as plt"

        # Find the last entry importing matplotlib (or configuring the notebook)
        # and an entry that uses `plt`. The usage index is overwritten on every
        # match, so it ends up at the LAST entry containing "plt.".
        mpl_last_import_index: int = -1
        mpl_usage_index: int = -1
        for i, line in enumerate(script_lines_new):
            if "matplotlib" in line and (("import" in line) or ("from" in line)):
                mpl_last_import_index = i

            # the `plt.show` override has to come after notebook configuration
            if "configure_notebook" in line:
                mpl_last_import_index = i

            if "plt." in line:
                mpl_usage_index = i

        assert (
            mpl_last_import_index != -1
        ), f"matplotlib imports not found! see line {mpl_last_import_index}"
        if mpl_usage_index != -1:
            assert (
                mpl_usage_index > mpl_last_import_index
            ), f"matplotlib plots created before import! see lines {mpl_usage_index}, {mpl_last_import_index}"
        else:
            warnings.warn(
                "could not find where matplotlib is used, plot disabling might not work!"
            )

        # insert the cursed things
        script_lines_new = (
            script_lines_new[: mpl_last_import_index + 1]
            + DISABLE_PLOTS["matplotlib"]
            + script_lines_new[mpl_last_import_index + 1 :]
        )
        # refresh the joined text so the circuitsvis check sees the new entries
        result_str_TEMP = "\n\n".join(script_lines_new)

    if "circuitsvis" in result_str_TEMP:
        # Same scan as above, for circuitsvis. As with matplotlib, the usage
        # index ends up at the LAST matching entry.
        cirv_last_import_index: int = -1
        cirv_configure_index: int = -1
        cirv_usage_index: int = -1

        for i, line in enumerate(script_lines_new):
            if "circuitsvis" in line:
                if ("import" in line) or ("from" in line):
                    cirv_last_import_index = i
                else:
                    cirv_usage_index = i

            # Tracked in its own variable: this used to be written to the
            # matplotlib index by mistake and therefore had no effect here.
            if "configure_notebook" in line:
                cirv_configure_index = i

            if "render" in line:
                cirv_usage_index = i

        # only a real circuitsvis import satisfies this check
        assert (
            cirv_last_import_index != -1
        ), f"circuitsvis imports not found! see line {cirv_last_import_index}"

        # insert after whichever comes later: the last import or the last
        # `configure_notebook` entry (mirrors the matplotlib branch)
        cirv_insert_after: int = max(cirv_last_import_index, cirv_configure_index)
        if cirv_usage_index != -1:
            assert (
                cirv_usage_index > cirv_insert_after
            ), f"circuitsvis plots created before import! see lines {cirv_usage_index}, {cirv_insert_after}"
        else:
            warnings.warn(
                "could not find where circuitsvis is used, plot disabling might not work!"
            )

        # insert the cursed things
        script_lines_new = (
            script_lines_new[: cirv_insert_after + 1]
            + DISABLE_PLOTS["circuitsvis"]
            + script_lines_new[cirv_insert_after + 1 :]
        )

    return script_lines_new
def _cell_source_lines(cell: dict) -> list[str]:
    """Return the `source` of a notebook cell as a list of lines (line endings kept).

    The notebook format allows `source` to be either a list of strings or one
    multi-line string; a single string is split so callers always see lines.
    """
    source = cell["source"]
    if isinstance(source, str):
        return source.splitlines(keepends=True)
    return list(source)


def convert_ipynb(
    notebook: dict,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = (
        "%",
        "!",
    ),  # ignore notebook magic commands and shell commands
) -> str:
    """Convert Jupyter Notebook to a script, doing some basic filtering and formatting.

    Markdown cells become triple-quoted strings, code cells are copied, and every
    emitted cell is preceded by `header_comment`. Cells of any other type are dropped.

    # Arguments
    - `notebook: dict`: Jupyter Notebook loaded as json.
    - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
    - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
    - `disable_plots: bool = False`: Disable plots in the output script.
    - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
      if a string is passed, it will be split by char and each char will be treated as a separate filter.

    # Returns
    - `str`: Converted script.

    # Raises
    - `AssertionError`: propagated from `disable_plots_in_script` when `disable_plots` is set.
    """

    if isinstance(filter_out_lines, str):
        filter_out_lines = tuple(filter_out_lines)
    # `str.startswith` accepts a tuple of prefixes, so one call tests all of them
    filter_prefixes: tuple[str, ...] = tuple(filter_out_lines)

    result: list[str] = []

    for cell in notebook["cells"]:
        cell_type: str = cell["cell_type"]

        if not strip_md_cells and cell_type == "markdown":
            md_text: str = "".join(_cell_source_lines(cell))
            result.append(f'{header_comment}\n"""\n{md_text}\n"""')
        elif cell_type == "code":
            # work on real lines: iterating a plain-string source would yield
            # single characters and mangle every `%`/`!` inside the code
            source: list[str] = _cell_source_lines(cell)
            if filter_prefixes:
                source = [
                    f"#{line}" if line.startswith(filter_prefixes) else line
                    for line in source
                ]
            code_text: str = "".join(source)
            result.append(f"{header_comment}\n{code_text}")

    if disable_plots:
        result = disable_plots_in_script(result)
        result = DISABLE_PLOTS_WARNING + result

    return "\n\n".join(result)
def process_file(
    in_file: str,
    out_file: str | None = None,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    """Convert a single Jupyter Notebook to a script.

    Progress and error messages go to stderr. The converted script is written
    to `out_file`, or printed to stdout when no output file is given.

    # Arguments
    - `in_file: str`: Path of the notebook to convert, must end in `.ipynb`.
    - `out_file: str|None = None`: Path of the script to write. If falsy, the script is printed to stdout.
    - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
    - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
    - `disable_plots: bool = False`: Disable plots in the output script.
    - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
      if a string is passed, it will be split by char and each char will be treated as a separate filter.

    # Raises
    - `AssertionError`: if `in_file` does not exist, is not a file, does not end
      in `.ipynb`, or if the conversion itself fails an assertion.
    """
    print(f"\tProcessing {in_file}...", file=sys.stderr)
    assert os.path.exists(in_file), f"File {in_file} does not exist."
    assert os.path.isfile(in_file), f"Path {in_file} is not a file."
    assert in_file.endswith(".ipynb"), f"File {in_file} is not a Jupyter Notebook."

    # explicit UTF-8 (same as `process_dir`) instead of the platform default encoding
    with open(in_file, "r", encoding="utf-8") as file:
        notebook: dict = json.load(file)

    try:
        converted_script: str = convert_ipynb(
            notebook=notebook,
            strip_md_cells=strip_md_cells,
            header_comment=header_comment,
            disable_plots=disable_plots,
            filter_out_lines=filter_out_lines,
        )
    except AssertionError as e:
        # say which file failed, then let the original error propagate unchanged
        print(f"Error converting {in_file}: {e}", file=sys.stderr)
        raise

    if out_file:
        with open(out_file, "w", encoding="utf-8") as file:
            file.write(converted_script)
    else:
        print(converted_script)
def process_dir(
    input_dir: typing.Union[str, Path],
    output_dir: typing.Union[str, Path],
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    """Convert all Jupyter Notebooks in a directory to scripts.

    Only files directly inside `input_dir` whose name ends in `.ipynb` are
    converted (no recursion). Each one is written to `output_dir` under the same
    name with the extension changed to `.py`.

    # Arguments
    - `input_dir: str`: Input directory.
    - `output_dir: str`: Output directory, created if it does not exist.
    - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
    - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
    - `disable_plots: bool = False`: Disable plots in the output script.
    - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
      if a string is passed, it will be split by char and each char will be treated as a separate filter.

    # Raises
    - `AssertionError`: if `input_dir` does not exist, is not a directory, or contains no notebooks.
    - `Exception`: if converting a notebook fails an assertion; the original error is chained as the cause.
    """

    assert os.path.exists(input_dir), f"Directory {input_dir} does not exist."
    assert os.path.isdir(input_dir), f"Path {input_dir} is not a directory."

    os.makedirs(output_dir, exist_ok=True)

    nb_suffix: str = ".ipynb"
    # sorted, because `os.listdir` returns entries in arbitrary order
    filenames: list[str] = sorted(
        fname for fname in os.listdir(input_dir) if fname.endswith(nb_suffix)
    )

    assert filenames, f"Directory {input_dir} does not contain any Jupyter Notebooks."
    n_files: int = len(filenames)
    print(f"Converting {n_files} notebooks:", file=sys.stderr)

    with SpinnerContext(
        spinner_chars="braille",
        update_interval=0.01,
        format_string_when_updated=True,
        output_stream=sys.stderr,
    ) as spinner:
        for idx, fname in enumerate(filenames):
            spinner.update_value(f"\tConverting {idx+1}/{n_files}: {fname}")
            in_file: str = os.path.join(input_dir, fname)
            # swap only the trailing extension; `str.replace` would also rewrite
            # any ".ipynb" occurring earlier in the file name
            out_name: str = fname[: -len(nb_suffix)] + ".py"
            out_file: str = os.path.join(output_dir, out_name)

            with open(in_file, "r", encoding="utf-8") as file_in:
                notebook: dict = json.load(file_in)

            try:
                converted_script: str = convert_ipynb(
                    notebook=notebook,
                    strip_md_cells=strip_md_cells,
                    header_comment=header_comment,
                    disable_plots=disable_plots,
                    filter_out_lines=filter_out_lines,
                )
            except AssertionError as e:
                spinner.stop()
                raise Exception(f"Error converting {in_file}") from e

            with open(out_file, "w", encoding="utf-8") as file_out:
                file_out.write(converted_script)
# Command-line entry point: convert one notebook (printed to stdout or written
# to --out-file) or, when --output-dir is given, every notebook in a directory.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert Jupyter Notebook to a script with cell separators."
    )
    parser.add_argument(
        "in_path",
        type=str,
        help="Input Jupyter Notebook file (.ipynb) or directory of files.",
    )
    parser.add_argument(
        "--out-file",
        type=str,
        help="Output script file. If not specified, the result will be printed to stdout.",
    )
    parser.add_argument(
        "--output-dir", type=str, help="Output directory for converted script files."
    )
    parser.add_argument(
        "--strip-md-cells",
        action="store_true",
        help="Remove markdown cells from the output script.",
    )
    parser.add_argument(
        "--header-comment",
        type=str,
        default=r"#%%",
        help="Comment string to separate cells in the output script.",
    )
    parser.add_argument(
        "--disable-plots",
        action="store_true",
        help="Disable plots in the output script. Useful for testing in CI.",
    )
    # NOTE(review): the CLI default filters only "%" lines, while the functions
    # default to ("%", "!") -- confirm whether the difference is intentional.
    parser.add_argument(
        "--filter-out-lines",
        type=str,
        default="%",
        help="Comment out lines starting with these characters.",
    )

    args = parser.parse_args()

    if args.output_dir:
        # `parser.error` prints the usage and exits with status 2; unlike
        # `assert`, it is not stripped when Python runs with -O
        if args.out_file:
            parser.error("Cannot specify both --out-file and --output-dir.")
        process_dir(
            input_dir=args.in_path,
            output_dir=args.output_dir,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )

    else:
        process_file(
            in_file=args.in_path,
            out_file=args.out_file,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )
# Status message emitted every time the module body runs (on import and at the
# end of a command-line run). It goes to stderr like the other diagnostics in
# this file: stdout carries the converted script when no output file is given,
# and this line must not end up inside it.
print("muutils.nbutils.convert_ipynb_to_script.py loaded.", file=sys.stderr)