Coverage for muutils/dbg.py: 69%
203 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-28 17:24 +0000
1"""
3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from
4https://github.com/tylerwince/pydbg/blob/master/pydbg.py
5although it has been significantly modified
7licensed under MIT:
9Copyright (c) 2019 Tyler Wince
11Permission is hereby granted, free of charge, to any person obtaining a copy
12of this software and associated documentation files (the "Software"), to deal
13in the Software without restriction, including without limitation the rights
14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15copies of the Software, and to permit persons to whom the Software is
16furnished to do so, subject to the following conditions:
18The above copyright notice and this permission notice shall be included in
19all copies or substantial portions of the Software.
21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27THE SOFTWARE.
29"""
31from __future__ import annotations
33import inspect
34import sys
35import typing
36from pathlib import Path
37import re
# type defs
# generic type of the expression passed through `dbg` (returned unchanged)
_ExpType = typing.TypeVar("_ExpType")
# dict-bound variant used by `dbg_dict` so the return type stays a dict
_ExpType_dict = typing.TypeVar(
    "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any]
)
# list-bound variant (declared for symmetry with `_ExpType_dict`)
_ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any])


# Sentinel type for no expression passed
class _NoExpPassedSentinel:
    """Unique sentinel type used to indicate that no expression was passed."""

    pass


# singleton sentinel; `dbg` compares against it with `is`
_NoExpPassed: _NoExpPassedSentinel = _NoExpPassedSentinel()
# global variables
# cwd is captured once at import time so relative-path display stays stable
# even if the process later changes directory
_CWD: Path = Path.cwd().absolute()
# counter shown by bare `dbg()` calls (no expression passed)
_COUNTER: int = 0

# configuration
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
DEFAULT_VAL_JOINER: str = " = "


# path processing
def _process_path(path: Path) -> str:
    """Format `path` for display according to the global `PATH_MODE`.

    Returns a posix-style string: relative to the captured cwd when
    `PATH_MODE` is "relative" and the path lies inside it, absolute
    otherwise.

    Raises:
        ValueError: if `PATH_MODE` is neither "relative" nor "absolute".
    """
    path_abs: Path = path.absolute()
    fname: Path
    if PATH_MODE == "absolute":
        fname = path_abs
    elif PATH_MODE == "relative":
        try:
            # if it's inside the cwd, print the relative path
            fname = path.relative_to(_CWD)
        except ValueError:
            # if it's not under the cwd, use the absolute path
            fname = path_abs
    else:
        # BUG FIX: message previously had an unterminated quote ("'absolute")
        raise ValueError("PATH_MODE must be either 'relative' or 'absolute'")

    return fname.as_posix()
# actual dbg function
# overload: bare `dbg()` returns the sentinel
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
# overload: sentinel in, sentinel out (explicit no-expression call)
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
# overload: any expression is returned with its own type preserved
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Prints to stderr the location of the call (file:lineno) together with the
    source text of the passed expression and its value (via `formatter` if
    given, else `repr`), then returns the expression unchanged so the call can
    be wrapped around existing code:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    Called with no argument, prints only the location and a global counter.
    """
    global _COUNTER

    # get the context: walk up the stack to the first frame whose source line
    # mentions "dbg" -- that is taken to be the call site
    line_exp: str = "unknown"
    current_file: str = "unknown"
    dbg_frame: typing.Optional[inspect.FrameInfo] = None
    for frame in inspect.stack():
        if frame.code_context is None:
            continue
        line: str = frame.code_context[0]
        if "dbg" in line:
            current_file = _process_path(Path(frame.filename))
            dbg_frame = frame
            # extract the argument text between the outermost parens
            # NOTE(review): purely textual -- a multi-line call site would
            # only show its first line here
            start: int = line.find("(") + 1
            end: int = line.rfind(")")
            if end == -1:
                end = len(line)
            line_exp = line[start:end]
            break

    fname: str = "unknown"
    # inside a jupyter kernel the filename is a temp file, so build a
    # "outer -> inner" call-chain description instead
    if current_file.startswith("/tmp/ipykernel_"):
        stack: list[inspect.FrameInfo] = inspect.stack()
        filtered_functions: list[str] = []
        # this loop will find, in this order:
        # - the dbg function call
        # - the functions we care about displaying
        # - `<module>`
        # - a bunch of jupyter internals we don't care about
        for frame_info in stack:
            if _process_path(Path(frame_info.filename)) != current_file:
                continue
            if frame_info.function == "<module>":
                break
            if frame_info.function.startswith("dbg"):
                continue
            filtered_functions.append(frame_info.function)
        if dbg_frame is not None:
            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
        else:
            filtered_functions.append(current_file)
        filtered_functions.reverse()
        fname = " -> ".join(filtered_functions)
    elif dbg_frame is not None:
        fname = f"{current_file}:{dbg_frame.lineno}"

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp
# formatted `dbg_*` functions with their helpers

# keyword arguments forwarded verbatim to
# `muutils.tensor_info.array_summary` by `tensor_info`
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[
    str, typing.Union[None, bool, int, str]
] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=True,
    sparkline_bins=7,
    sparkline_logy=None,  # None means auto-detect
    colored=True,
    eq_char="=",
)


# joiner between expression text and value for the `dbg_*` formatters
DBG_TENSOR_VAL_JOINER: str = ": "
def tensor_info(tensor: typing.Any) -> str:
    """Summarize a tensor/array via `muutils.tensor_info.array_summary`.

    Uses the module-level `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS` as options;
    import is deferred so this module loads without the tensor_info deps.
    """
    from muutils.tensor_info import array_summary

    return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)
# options read by `dict_info`
DBG_DICT_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = dict(
    key_types=True,  # show the set of key type names in the summary
    val_types=True,  # show the set of value type names in the summary
    max_len=32,  # expand entries only for dicts shorter than this
    indent="  ",
    max_depth=3,  # recursion limit for nested dicts
)
# options read by `list_info`
DBG_LIST_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = dict(
    max_len=16,
    summary_show_types=True,
)


def list_info(
    lst: typing.List[typing.Any],
) -> str:
    """Render a list for debugging.

    Lists with at most `max_len` elements get a full element-by-element
    repr; longer lists collapse to a `<list of len()=N, types={...}>`
    summary (types shown when `summary_show_types` is set).
    """
    n_items: int = len(lst)
    # TYPING: make `DBG_LIST_DEFAULTS` and the others typed dicts
    if n_items <= DBG_LIST_DEFAULTS["max_len"]:  # type: ignore[operator]
        return "[" + ", ".join(repr(item) for item in lst) + "]"
    pieces: typing.List[str] = [f"<list of len()={n_items}"]
    if DBG_LIST_DEFAULTS["summary_show_types"]:
        type_names: typing.List[str] = sorted({type(item).__name__ for item in lst})
        pieces.append(f", types={{{', '.join(type_names)}}}")
    pieces.append(">")
    return "".join(pieces)
# tensor types matched by their `str(type(v))` form, so that neither torch
# nor numpy has to be imported just to recognize them
TENSOR_STR_TYPES: typing.Set[str] = {
    "<class 'torch.Tensor'>",
    "<class 'numpy.ndarray'>",
}
def dict_info(
    d: typing.Dict[typing.Any, typing.Any],
    depth: int = 0,
) -> str:
    """Render a dict for debugging: a summary line plus (shallow) entries.

    Options come from the module-level `DBG_DICT_DEFAULTS`. Entries are
    expanded only while `depth < max_depth` and `0 < len(d) < max_len`;
    values dispatch to `dict_info` (recursively), `tensor_info`, `list_info`,
    or `repr` depending on their type.
    """
    len_d: int = len(d)
    indent: str = DBG_DICT_DEFAULTS["indent"]  # type: ignore[assignment]

    # summary line
    output: str = f"{indent * depth}<dict of len()={len_d}"

    if DBG_DICT_DEFAULTS["key_types"] and len_d > 0:
        key_types: typing.Set[str] = set(type(k).__name__ for k in d.keys())
        key_types_str: str = "{" + ", ".join(sorted(key_types)) + "}"
        output += f", key_types={key_types_str}"

    if DBG_DICT_DEFAULTS["val_types"] and len_d > 0:
        val_types: typing.Set[str] = set(type(v).__name__ for v in d.values())
        val_types_str: str = "{" + ", ".join(sorted(val_types)) + "}"
        output += f", val_types={val_types_str}"

    output += ">"

    # keys/values if not too deep and not too many
    if depth < DBG_DICT_DEFAULTS["max_depth"]:  # type: ignore[operator]
        if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]:  # type: ignore[operator]
            for k, v in d.items():
                # string keys are shown bare, everything else via repr
                key_str: str = repr(k) if not isinstance(k, str) else k

                val_str: str
                # compare by `str(type(v))` to avoid importing torch/numpy
                val_type_str: str = str(type(v))
                if isinstance(v, dict):
                    val_str = dict_info(v, depth + 1)
                elif val_type_str in TENSOR_STR_TYPES:
                    val_str = tensor_info(v)
                elif isinstance(v, list):
                    val_str = list_info(v)
                else:
                    val_str = repr(v)

                output += (
                    f"\n{indent * (depth + 1)}{key_str}{DBG_TENSOR_VAL_JOINER}{val_str}"
                )

    return output
def info_auto(
    obj: typing.Any,
) -> str:
    """Automatically format an object for debugging.

    Dispatches on type: dict -> `dict_info`, list -> `list_info`,
    torch/numpy tensor -> `tensor_info`, anything else -> `repr`.
    """
    if isinstance(obj, dict):
        return dict_info(obj)
    if isinstance(obj, list):
        return list_info(obj)
    if str(type(obj)) in TENSOR_STR_TYPES:
        return tensor_info(obj)
    return repr(obj)
def dbg_tensor(
    tensor: _ExpType,  # numpy array or torch tensor
) -> _ExpType:
    """Debug-print a tensor via the `tensor_info` summary formatter.

    Returns the tensor unchanged, like `dbg`.
    """
    return dbg(tensor, formatter=tensor_info, val_joiner=DBG_TENSOR_VAL_JOINER)
def dbg_dict(
    d: _ExpType_dict,
) -> _ExpType_dict:
    """Debug-print a dictionary via the `dict_info` formatter.

    Returns the dict unchanged, like `dbg`.
    """
    return dbg(d, formatter=dict_info, val_joiner=DBG_TENSOR_VAL_JOINER)
def dbg_auto(
    obj: _ExpType,
) -> _ExpType:
    """Debug-print any object, picking a formatter by type via `info_auto`.

    Returns the object unchanged, like `dbg`.
    """
    return dbg(obj, formatter=info_auto, val_joiner=DBG_TENSOR_VAL_JOINER)
def _normalize_for_loose(text: str) -> str:
    """Normalize `text` for loose matching.

    Every run of non-alphanumeric characters becomes a single space,
    with no leading or trailing whitespace in the result.
    """
    tokens: typing.List[str] = re.split(r"[^a-zA-Z0-9]+", text)
    return " ".join(tok for tok in tokens if tok)
def _compile_pattern(
    pattern: str | re.Pattern[str],
    *,
    cased: bool = False,
    loose: bool = False,
) -> re.Pattern[str]:
    """Compile `pattern` honoring the `cased` and `loose` options.

    Pre-compiled patterns are returned as-is (their own flags win).
    String patterns are case-insensitive unless `cased` is set, and are
    space/punctuation-normalized first when `loose` is set.
    """
    if isinstance(pattern, re.Pattern):
        return pattern

    if loose:
        pattern = _normalize_for_loose(pattern)

    # default is case-insensitive matching
    flags: int = 0 if cased else re.IGNORECASE
    return re.compile(pattern, flags)
def grep_repr(
    obj: typing.Any,
    pattern: str | re.Pattern[str],
    *,
    char_context: int | None = 20,
    line_context: int | None = None,
    before_context: int = 0,
    after_context: int = 0,
    context: int | None = None,
    max_count: int | None = None,
    cased: bool = False,
    loose: bool = False,
    line_numbers: bool = False,
    highlight: bool = True,
    color: str = "31",
    separator: str = "--",
    quiet: bool = False,
) -> typing.List[str] | None:
    """grep-like search on ``repr(obj)`` with improved grep-style options.

    By default, string patterns are case-insensitive. Pre-compiled regex
    patterns use their own flags.

    Parameters:
    - obj: Object to search (its repr() string is scanned)
    - pattern: Regular expression pattern (string or pre-compiled)
    - char_context: Characters of context before/after each match (default: 20)
    - line_context: Lines of context before/after; overrides char_context
    - before_context: Lines of context before match (like grep -B)
    - after_context: Lines of context after match (like grep -A)
    - context: Lines of context before AND after (like grep -C)
    - max_count: Stop after this many matches
    - cased: Force case-sensitive search for string patterns
    - loose: Normalize spaces/punctuation for flexible matching
    - line_numbers: Show line numbers in output
    - highlight: Wrap matches with ANSI color codes
    - color: ANSI color code (default: "31" for red)
    - separator: Separator between multiple matches
    - quiet: Return results instead of printing

    Returns:
    - None if quiet=False (prints to stdout)
    - List[str] if quiet=True (returns formatted output lines)
    """
    # Handle context parameter shortcuts
    if context is not None:
        before_context = after_context = context

    # Prepare text and pattern; in loose mode BOTH the haystack and the
    # (string) pattern are normalized the same way
    text: str = repr(obj)
    if loose:
        text = _normalize_for_loose(text)

    regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose)

    def _color_match(segment: str) -> str:
        # re-applies the regex to the segment to wrap each hit in ANSI codes
        if not highlight:
            return segment
        return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment)

    output_lines: list[str] = []
    match_count: int = 0

    # Determine if we're using line-based context
    using_line_context = (
        line_context is not None or before_context > 0 or after_context > 0
    )

    if using_line_context:
        # line-based mode: map each match to its line, then emit a block of
        # surrounding lines (grep -A/-B/-C style)
        lines: list[str] = text.splitlines()
        line_starts: list[int] = []
        pos: int = 0
        for line in lines:
            line_starts.append(pos)
            pos += len(line) + 1  # +1 for newline

        processed_lines: set[int] = set()

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # Find which line contains this match (last line start <= match start)
            match_line = max(
                i for i, start in enumerate(line_starts) if start <= match.start()
            )

            # Calculate context range
            ctx_before: int
            ctx_after: int
            if line_context is not None:
                ctx_before = ctx_after = line_context
            else:
                ctx_before, ctx_after = before_context, after_context

            start_line: int = max(0, match_line - ctx_before)
            end_line: int = min(len(lines), match_line + ctx_after + 1)

            # Avoid duplicate output for overlapping contexts
            # NOTE(review): a match whose context overlaps an earlier block is
            # skipped entirely and does NOT count toward max_count
            line_range: set[int] = set(range(start_line, end_line))
            if line_range & processed_lines:
                continue
            processed_lines.update(line_range)

            # Format the context block
            context_lines: list[str] = []
            for i in range(start_line, end_line):
                line_text = lines[i]
                if line_numbers:
                    line_prefix = f"{i + 1}:"
                    line_text = f"{line_prefix}{line_text}"
                context_lines.append(_color_match(line_text))

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.extend(context_lines)
            match_count += 1

    else:
        # Character-based context: a fixed-width snippet around each match
        ctx: int = 0 if char_context is None else char_context

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            start: int = max(0, match.start() - ctx)
            end: int = min(len(text), match.end() + ctx)
            snippet: str = text[start:end]

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.append(_color_match(snippet))
            match_count += 1

    if quiet:
        return output_lines
    else:
        for line in output_lines:
            print(line)
        return None