Coverage for muutils / dbg.py: 87%
228 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-18 21:32 -0600
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-18 21:32 -0600
1"""
3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from
4https://github.com/tylerwince/pydbg/blob/master/pydbg.py
5although it has been significantly modified
7licensed under MIT:
9Copyright (c) 2019 Tyler Wince
11Permission is hereby granted, free of charge, to any person obtaining a copy
12of this software and associated documentation files (the "Software"), to deal
13in the Software without restriction, including without limitation the rights
14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15copies of the Software, and to permit persons to whom the Software is
16furnished to do so, subject to the following conditions:
18The above copyright notice and this permission notice shall be included in
19all copies or substantial portions of the Software.
21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27THE SOFTWARE.
29"""
31from __future__ import annotations
33import inspect
34import sys
35import typing
36from pathlib import Path
37import re
# type defs
# generic type var so `dbg(x)` can return `x` with its static type preserved
_ExpType = typing.TypeVar("_ExpType")
# dict-bound variant used by `dbg_dict` so the dict subtype flows through
_ExpType_dict = typing.TypeVar(
    "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any]
)
# list-bound variant (currently unused by the public helpers but kept for symmetry)
_ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any])
# TypedDict definitions for configuration dictionaries
class DBGDictDefaultsType(typing.TypedDict):
    """Schema for `DBG_DICT_DEFAULTS`, the options used by `dict_info`."""

    # include the set of key type names in the summary line
    key_types: bool
    # include the set of value type names in the summary line
    val_types: bool
    # dicts with at least this many entries are summarized without items
    max_len: int
    # indentation unit prepended once per nesting depth
    indent: str
    # nesting depth beyond which items are no longer expanded
    max_depth: int
class DBGListDefaultsType(typing.TypedDict):
    """Schema for `DBG_LIST_DEFAULTS`, the options used by `list_info`."""

    # lists longer than this are collapsed to a summary instead of shown in full
    max_len: int
    # when summarizing, also show the set of element type names
    summary_show_types: bool
class DBGTensorArraySummaryDefaultsType(typing.TypedDict):
    """Schema for `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`.

    These keys are forwarded as keyword arguments to
    `muutils.tensor_info.array_summary` -- see that function for the exact
    semantics of each option.
    """

    # rendering style of the summary text
    fmt: typing.Literal["unicode", "latex", "ascii"]
    # decimal precision for printed numbers
    precision: int
    # toggles for the individual fields of the summary
    # (names mirror the `array_summary` parameters)
    stats: bool
    shape: bool
    dtype: bool
    device: bool
    requires_grad: bool
    # inline histogram ("sparkline") options
    sparkline: bool
    sparkline_bins: int
    # None means the log-y choice is left to `array_summary` (auto-detect)
    sparkline_logy: typing.Union[None, bool]
    colored: bool
    eq_char: str
# Sentinel type for no expression passed
class _NoExpPassedSentinel:
    """Marker class whose single instance signals "no expression was passed"."""


# the one shared sentinel instance, compared against with `is`
_NoExpPassed = _NoExpPassedSentinel()
# global variables
# cwd captured once at import time; basis for "relative" path display
_CWD: Path = Path.cwd().absolute()
# incremented each time dbg() is called with no expression
_COUNTER: int = 0
# resolved path of this module, used to skip its own frames during stack walks
_DBG_MODULE_FILE: str = str(Path(__file__).resolve())

# configuration
# how caller file paths are displayed: relative to _CWD when possible, or absolute
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
# default separator between the expression's source text and its value
DEFAULT_VAL_JOINER: str = " = "
# path processing
def _process_path(path: Path) -> str:
    """Format `path` for display according to the module-level `PATH_MODE`.

    - `"relative"`: show the path relative to the cwd captured at import
      time (`_CWD`) when the path lies under it, else the absolute path.
    - `"absolute"`: always show the absolute path.

    # Returns:
    - `str` : POSIX-style (forward-slash) path string

    # Raises:
    - `ValueError` : if `PATH_MODE` is set to an unrecognized value
    """
    path_abs: Path = path.absolute()
    fname: Path
    if PATH_MODE == "absolute":
        fname = path_abs
    elif PATH_MODE == "relative":
        try:
            # if it's inside the cwd, print the relative path
            # (compare the absolute form so a relative input is matched too)
            fname = path_abs.relative_to(_CWD)
        except ValueError:
            # not under the cwd -- fall back to the absolute path
            fname = path_abs
    else:
        # fixed: error message previously lacked the closing quote on 'absolute'
        raise ValueError("PATH_MODE must be either 'relative' or 'absolute'")

    return fname.as_posix()
# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Prints to stderr the caller's file and line number, the source text of
    the passed expression, and what it evaluates to; the value is returned
    unchanged so `dbg(...)` can be dropped into the middle of any expression.
    Called with no argument, it prints only the location and a global counter.

    # Parameters:
    - `exp` : value to inspect; when omitted, a sentinel is used and only
      the location/counter line is printed
    - `formatter` : optional callable used instead of `repr` to render the value
    - `val_joiner` : string placed between the expression text and its value

    # Usage:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))
    """
    global _COUNTER

    # get the context: walk up the stack to the first frame that is not in
    # this module and whose source line mentions "dbg"
    line_exp: str = "unknown"
    current_file: str = "unknown"
    dbg_frame: typing.Optional[inspect.FrameInfo] = None
    for frame in inspect.stack():
        if frame.code_context is None:
            continue
        # skip frames belonging to this module itself (e.g. the dbg_* wrappers)
        if str(Path(frame.filename).resolve()) == _DBG_MODULE_FILE:
            continue
        line: str = frame.code_context[0]
        if "dbg" in line:
            current_file = _process_path(Path(frame.filename))
            dbg_frame = frame
            # extract the argument text between the first "(" and the last ")"
            # NOTE(review): heuristic -- a line with other calls around dbg()
            # may capture more text than just the dbg argument
            start: int = line.find("(") + 1
            end: int = line.rfind(")")
            if end == -1:
                # no closing paren on this source line; take the rest of it
                end = len(line)
            line_exp = line[start:end]
            break

    fname: str = "unknown"
    # inside a Jupyter kernel the "file" is a temp cell file, so build a
    # "outer -> inner -> <ipykernel>:lineno" breadcrumb of caller functions
    if current_file.startswith("/tmp/ipykernel_"):
        stack: list[inspect.FrameInfo] = inspect.stack()
        filtered_functions: list[str] = []
        # this loop will find, in this order:
        # - the dbg function call
        # - the functions we care about displaying
        # - `<module>`
        # - a bunch of jupyter internals we don't care about
        for frame_info in stack:
            if _process_path(Path(frame_info.filename)) != current_file:
                continue
            if frame_info.function == "<module>":
                break
            if frame_info.function.startswith("dbg"):
                continue
            filtered_functions.append(frame_info.function)
        if dbg_frame is not None:
            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
        else:
            filtered_functions.append(current_file)
        filtered_functions.reverse()
        fname = " -> ".join(filtered_functions)
    elif dbg_frame is not None:
        fname = f"{current_file}:{dbg_frame.lineno}"

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp
# formatted `dbg_*` functions with their helpers

# default options forwarded to `muutils.tensor_info.array_summary` by `tensor_info`;
# see `DBGTensorArraySummaryDefaultsType` for the schema
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: DBGTensorArraySummaryDefaultsType = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": None,  # None means auto-detect
    "colored": True,
    "eq_char": "=",
}

# joiner used by the formatted dbg_* helpers (": " instead of the default " = ")
DBG_TENSOR_VAL_JOINER: str = ": "
def tensor_info(tensor: typing.Any) -> str:
    """Render a one-line summary of a tensor/array via `array_summary`,
    using the module-level `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS` options."""
    # imported lazily so importing this module never pulls in tensor deps
    from muutils.tensor_info import array_summary

    # TODO: explicitly pass args to avoid type: ignore (mypy can't match overloads with **TypedDict spread)
    summary: str = array_summary(  # type: ignore[call-overload]
        tensor,
        as_list=False,
        **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS,
    )
    return summary
# options used by `dict_info`; see `DBGDictDefaultsType` for field meanings
DBG_DICT_DEFAULTS: DBGDictDefaultsType = {
    "key_types": True,
    "val_types": True,
    "max_len": 32,
    "indent": "  ",
    "max_depth": 3,
}

# options used by `list_info`; see `DBGListDefaultsType` for field meanings
DBG_LIST_DEFAULTS: DBGListDefaultsType = {
    "max_len": 16,
    "summary_show_types": True,
}
def list_info(
    lst: typing.List[typing.Any],
) -> str:
    """Format a list for debug output.

    Lists longer than `DBG_LIST_DEFAULTS["max_len"]` collapse to a summary
    of their length (optionally with the set of element type names); shorter
    lists are rendered element-by-element with `repr`.
    """
    length: int = len(lst)

    # short lists: show every element
    if length <= DBG_LIST_DEFAULTS["max_len"]:
        return "[" + ", ".join(repr(item) for item in lst) + "]"

    # long lists: length summary, optionally with element types
    summary: str = f"<list of len()={length}"
    if DBG_LIST_DEFAULTS["summary_show_types"]:
        type_names: typing.Set[str] = {type(item).__name__ for item in lst}
        summary += f", types={{{', '.join(sorted(type_names))}}}"
    return summary + ">"
# `str(type(x))` values treated as tensors/arrays; comparing the string
# avoids importing torch/numpy just to do an isinstance check
TENSOR_STR_TYPES: typing.Set[str] = {
    "<class 'torch.Tensor'>",
    "<class 'numpy.ndarray'>",
}
def dict_info(
    d: typing.Dict[typing.Any, typing.Any],
    depth: int = 0,
) -> str:
    """Recursively format a dict for debug output.

    Always emits a summary line `<dict of len()=N, key_types={...}, val_types={...}>`
    (type sets controlled by `DBG_DICT_DEFAULTS`); when the dict is non-empty,
    smaller than `max_len`, and `depth` is below `max_depth`, each item is
    appended on its own indented line. Values dispatch to `dict_info`
    (recursively), `tensor_info`, or `list_info` by type, else `repr`.

    # Parameters:
    - `d` : the dict to format
    - `depth` : current nesting level, controls indentation and recursion cutoff
    """
    len_d: int = len(d)
    indent: str = DBG_DICT_DEFAULTS["indent"]

    # summary line
    output: str = f"{indent * depth}<dict of len()={len_d}"

    if DBG_DICT_DEFAULTS["key_types"] and len_d > 0:
        key_types: typing.Set[str] = set(type(k).__name__ for k in d.keys())
        key_types_str: str = "{" + ", ".join(sorted(key_types)) + "}"
        output += f", key_types={key_types_str}"

    if DBG_DICT_DEFAULTS["val_types"] and len_d > 0:
        val_types: typing.Set[str] = set(type(v).__name__ for v in d.values())
        val_types_str: str = "{" + ", ".join(sorted(val_types)) + "}"
        output += f", val_types={val_types_str}"

    output += ">"

    # keys/values if not too deep and not too many
    if depth < DBG_DICT_DEFAULTS["max_depth"]:
        if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]:
            for k, v in d.items():
                # string keys are shown bare, everything else via repr
                key_str: str = repr(k) if not isinstance(k, str) else k

                # dispatch the value to the appropriate formatter
                val_str: str
                val_type_str: str = str(type(v))
                if isinstance(v, dict):
                    val_str = dict_info(v, depth + 1)
                elif val_type_str in TENSOR_STR_TYPES:
                    val_str = tensor_info(v)
                elif isinstance(v, list):
                    val_str = list_info(v)
                else:
                    val_str = repr(v)

                # NOTE: items are joined with DBG_TENSOR_VAL_JOINER (": "),
                # matching the style of the formatted dbg_* helpers
                output += (
                    f"\n{indent * (depth + 1)}{key_str}{DBG_TENSOR_VAL_JOINER}{val_str}"
                )

    return output
def info_auto(
    obj: typing.Any,
) -> str:
    """Pick a debug formatter for `obj` by type.

    Dicts and lists get structured summaries, tensors/arrays (matched by
    their type's string form) get a stats line, everything else uses `repr`.
    """
    if isinstance(obj, dict):
        return dict_info(obj)
    if isinstance(obj, list):
        return list_info(obj)
    if str(type(obj)) in TENSOR_STR_TYPES:
        return tensor_info(obj)
    return repr(obj)
def dbg_tensor(
    tensor: _ExpType,  # numpy array or torch tensor
) -> _ExpType:
    """Debug-print a tensor with the `tensor_info` formatter; returns it unchanged."""
    return dbg(tensor, formatter=tensor_info, val_joiner=DBG_TENSOR_VAL_JOINER)
def dbg_dict(
    d: _ExpType_dict,
) -> _ExpType_dict:
    """Debug-print a dict with the `dict_info` formatter; returns it unchanged."""
    return dbg(d, formatter=dict_info, val_joiner=DBG_TENSOR_VAL_JOINER)
def dbg_auto(
    obj: _ExpType,
) -> _ExpType:
    """Debug-print any object, choosing a formatter by type via `info_auto`."""
    return dbg(obj, formatter=info_auto, val_joiner=DBG_TENSOR_VAL_JOINER)
def _normalize_for_loose(text: str) -> str:
    """Normalize text for loose matching by replacing non-alphanumeric chars with spaces."""
    # each run of non-alphanumerics becomes exactly one space, so a final
    # strip() is all that's needed to drop leading/trailing separators
    return re.sub(r"[^a-zA-Z0-9]+", " ", text).strip()
def _compile_pattern(
    pattern: str | re.Pattern[str],
    *,
    cased: bool = False,
    loose: bool = False,
) -> re.Pattern[str]:
    """Compile a search pattern, honoring case sensitivity and loose matching.

    Pre-compiled patterns are returned untouched (their own flags win);
    string patterns are case-insensitive unless `cased=True`, and are
    space/punctuation-normalized when `loose=True`.
    """
    if isinstance(pattern, re.Pattern):
        return pattern

    if loose:
        pattern = _normalize_for_loose(pattern)

    # default to case-insensitive; `cased=True` forces exact-case matching
    flags: int = 0 if cased else re.IGNORECASE
    return re.compile(pattern, flags)
def grep_repr(
    obj: typing.Any,
    pattern: str | re.Pattern[str],
    *,
    char_context: int | None = 20,
    line_context: int | None = None,
    before_context: int = 0,
    after_context: int = 0,
    context: int | None = None,
    max_count: int | None = None,
    cased: bool = False,
    loose: bool = False,
    line_numbers: bool = False,
    highlight: bool = True,
    color: str = "31",
    separator: str = "--",
    quiet: bool = False,
) -> typing.List[str] | None:
    """grep-like search on ``repr(obj)`` with improved grep-style options.

    By default, string patterns are case-insensitive. Pre-compiled regex
    patterns use their own flags.

    Parameters:
    - obj: Object to search (its repr() string is scanned)
    - pattern: Regular expression pattern (string or pre-compiled)
    - char_context: Characters of context before/after each match (default: 20)
    - line_context: Lines of context before/after; overrides char_context
    - before_context: Lines of context before match (like grep -B)
    - after_context: Lines of context after match (like grep -A)
    - context: Lines of context before AND after (like grep -C)
    - max_count: Stop after this many matches
    - cased: Force case-sensitive search for string patterns
    - loose: Normalize spaces/punctuation for flexible matching
    - line_numbers: Show line numbers in output
    - highlight: Wrap matches with ANSI color codes
    - color: ANSI color code (default: "31" for red)
    - separator: Separator between multiple matches
    - quiet: Return results instead of printing

    Returns:
    - None if quiet=False (prints to stdout)
    - List[str] if quiet=True (returns formatted output lines)
    """
    # Handle context parameter shortcuts
    if context is not None:
        before_context = after_context = context

    # Prepare text and pattern
    # NOTE: when loose=True, both the text and the pattern are normalized,
    # so matches are found in the normalized text (output shows that form)
    text: str = repr(obj)
    if loose:
        text = _normalize_for_loose(text)

    regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose)

    def _color_match(segment: str) -> str:
        # wrap every pattern occurrence in the segment with ANSI bold+color
        if not highlight:
            return segment
        return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment)

    output_lines: list[str] = []
    match_count: int = 0

    # Determine if we're using line-based context
    # (any line-context option switches the whole search to line mode)
    using_line_context = (
        line_context is not None or before_context > 0 or after_context > 0
    )

    if using_line_context:
        lines: list[str] = text.splitlines()
        # line_starts[i] is the character offset of line i within `text`,
        # used to map a match's character position back to a line index
        line_starts: list[int] = []
        pos: int = 0
        for line in lines:
            line_starts.append(pos)
            pos += len(line) + 1  # +1 for newline

        processed_lines: set[int] = set()

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # Find which line contains this match
            match_line = max(
                i for i, start in enumerate(line_starts) if start <= match.start()
            )

            # Calculate context range
            ctx_before: int
            ctx_after: int
            if line_context is not None:
                ctx_before = ctx_after = line_context
            else:
                ctx_before, ctx_after = before_context, after_context

            start_line: int = max(0, match_line - ctx_before)
            end_line: int = min(len(lines), match_line + ctx_after + 1)

            # Avoid duplicate output for overlapping contexts
            # NOTE: a match whose context overlaps an already-printed block is
            # skipped entirely and does NOT count toward max_count
            line_range: set[int] = set(range(start_line, end_line))
            if line_range & processed_lines:
                continue
            processed_lines.update(line_range)

            # Format the context block
            context_lines: list[str] = []
            for i in range(start_line, end_line):
                line_text = lines[i]
                if line_numbers:
                    line_prefix = f"{i + 1}:"
                    line_text = f"{line_prefix}{line_text}"
                context_lines.append(_color_match(line_text))

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.extend(context_lines)
            match_count += 1

    else:
        # Character-based context: a fixed window of chars around each match
        ctx: int = 0 if char_context is None else char_context

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # clamp the window to the bounds of the text
            start: int = max(0, match.start() - ctx)
            end: int = min(len(text), match.end() + ctx)
            snippet: str = text[start:end]

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.append(_color_match(snippet))
            match_count += 1

    if quiet:
        return output_lines
    else:
        for line in output_lines:
            print(line)
        return None