Coverage for muutils / dbg.py: 87%
225 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
1"""
3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from
4https://github.com/tylerwince/pydbg/blob/master/pydbg.py
5although it has been significantly modified
7licensed under MIT:
9Copyright (c) 2019 Tyler Wince
11Permission is hereby granted, free of charge, to any person obtaining a copy
12of this software and associated documentation files (the "Software"), to deal
13in the Software without restriction, including without limitation the rights
14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15copies of the Software, and to permit persons to whom the Software is
16furnished to do so, subject to the following conditions:
18The above copyright notice and this permission notice shall be included in
19all copies or substantial portions of the Software.
21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27THE SOFTWARE.
29"""
31from __future__ import annotations
33import inspect
34import sys
35import typing
36from pathlib import Path
37import re
# type defs
# generic type of the expression passed to (and returned from) `dbg`
_ExpType = typing.TypeVar("_ExpType")
# dict-bound variant, used by `dbg_dict` so the return type matches the input dict
_ExpType_dict = typing.TypeVar(
    "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any]
)
# list-bound variant (no list-specific dbg wrapper is visible in this chunk)
_ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any])
# TypedDict definitions for configuration dictionaries
class DBGDictDefaultsType(typing.TypedDict):
    """Schema for `DBG_DICT_DEFAULTS` — options controlling `dict_info` output."""

    key_types: bool  # include the set of key type names in the summary line
    val_types: bool  # include the set of value type names in the summary line
    max_len: int  # dicts with this many entries or more are not expanded per-entry
    indent: str  # indentation unit repeated per nesting depth
    max_depth: int  # nested dicts deeper than this are summarized only
class DBGListDefaultsType(typing.TypedDict):
    """Schema for `DBG_LIST_DEFAULTS` — options controlling `list_info` output."""

    max_len: int  # lists longer than this are summarized instead of fully repr'd
    summary_show_types: bool  # include element type names in the summary form
class DBGTensorArraySummaryDefaultsType(typing.TypedDict):
    """Schema for `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`.

    Keys mirror the keyword arguments of `muutils.tensor_info.array_summary`
    (forwarded via `**` spread in `tensor_info`) — see that function for the
    authoritative semantics of each option.
    """

    fmt: typing.Literal["unicode", "latex", "ascii"]  # output character set
    precision: int
    stats: bool
    shape: bool
    dtype: bool
    device: bool
    requires_grad: bool
    sparkline: bool
    sparkline_bins: int
    sparkline_logy: typing.Union[None, bool]  # None means auto-detect
    colored: bool
    eq_char: str
# Sentinel type for no expression passed
class _NoExpPassedSentinel:
    """Unique sentinel type used to indicate that no expression was passed."""

    pass


# singleton instance — `dbg` compares against it by identity (`exp is _NoExpPassed`)
_NoExpPassed = _NoExpPassedSentinel()
# global variables
# working directory captured at import time; used by `_process_path` to relativize paths
_CWD: Path = Path.cwd().absolute()
# running counter shown by bare `dbg()` calls (no expression passed)
_COUNTER: int = 0

# configuration
# controls how `_process_path` renders file locations in dbg output
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
# default joiner placed between the expression text and its value
DEFAULT_VAL_JOINER: str = " = "
94# path processing
def _process_path(path: Path) -> str:
    """Render `path` as a posix-style string according to `PATH_MODE`.

    In "absolute" mode the absolute path is used; in "relative" mode the
    path is made relative to the cwd captured at import time (`_CWD`),
    falling back to the absolute path when it lies outside the cwd.

    # Raises:
    - `ValueError` : if `PATH_MODE` is neither "relative" nor "absolute"
    """
    path_abs: Path = path.absolute()
    fname: Path
    if PATH_MODE == "absolute":
        fname = path_abs
    elif PATH_MODE == "relative":
        try:
            # if it's inside the cwd, print the relative path
            fname = path.relative_to(_CWD)
        except ValueError:
            # if it's not in the subpath, use the absolute path
            fname = path_abs
    else:
        # bugfix: the original message was missing the closing quote on 'absolute'
        raise ValueError("PATH_MODE must be either 'relative' or 'absolute'")

    return fname.as_posix()
# actual dbg function
# overload 1: `dbg()` with no arguments returns the sentinel
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
# overload 2: explicitly passing the sentinel also returns the sentinel
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
# overload 3: passing any real expression returns that expression's type
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    # Parameters:
    - `exp` : the value to debug-print; when omitted, only the call location
        and a global counter are printed
    - `formatter` : optional callable used instead of `repr` to render the value
    - `val_joiner` : string placed between the expression text and its value

    # Returns:
    `exp` unchanged (or the sentinel when no expression was passed), so a
    `dbg(...)` wrapper can be dropped into existing expressions transparently
    """
    global _COUNTER

    # get the context: find the caller's frame whose source line contains
    # "dbg", and recover the argument text between the outermost parentheses
    line_exp: str = "unknown"
    current_file: str = "unknown"
    dbg_frame: typing.Optional[inspect.FrameInfo] = None
    for frame in inspect.stack():
        if frame.code_context is None:
            continue
        line: str = frame.code_context[0]
        if "dbg" in line:
            current_file = _process_path(Path(frame.filename))
            dbg_frame = frame
            # expression text = between first `(` and last `)` on the line;
            # if no closing paren (call spans lines), take the rest of the line
            start: int = line.find("(") + 1
            end: int = line.rfind(")")
            if end == -1:
                end = len(line)
            line_exp = line[start:end]
            break

    fname: str = "unknown"
    # ipython kernels run cells from temp files; show a call chain instead of
    # the meaningless temp path
    if current_file.startswith("/tmp/ipykernel_"):
        stack: list[inspect.FrameInfo] = inspect.stack()
        filtered_functions: list[str] = []
        # this loop will find, in this order:
        # - the dbg function call
        # - the functions we care about displaying
        # - `<module>`
        # - a bunch of jupyter internals we don't care about
        for frame_info in stack:
            if _process_path(Path(frame_info.filename)) != current_file:
                continue
            if frame_info.function == "<module>":
                break
            if frame_info.function.startswith("dbg"):
                continue
            filtered_functions.append(frame_info.function)

        if dbg_frame is not None:
            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
        else:
            filtered_functions.append(current_file)
        # innermost frame was appended first; reverse for outer -> inner order
        filtered_functions.reverse()
        fname = " -> ".join(filtered_functions)
    elif dbg_frame is not None:
        fname = f"{current_file}:{dbg_frame.lineno}"

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp
# formatted `dbg_*` functions with their helpers

# defaults forwarded as keyword arguments to `muutils.tensor_info.array_summary`
# (see `DBGTensorArraySummaryDefaultsType` for the key schema)
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: DBGTensorArraySummaryDefaultsType = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": None,  # None means auto-detect
    "colored": True,
    "eq_char": "=",
}

# joiner used by the formatted dbg_* variants (`expr: value` instead of `expr = value`)
DBG_TENSOR_VAL_JOINER: str = ": "
def tensor_info(tensor: typing.Any) -> str:
    """Summarize a tensor/array via `array_summary` using the module defaults."""
    # imported lazily so merely importing this module doesn't pull in tensor deps
    from muutils.tensor_info import array_summary

    # TODO: explicitly pass args to avoid type: ignore (mypy can't match overloads with **TypedDict spread)
    summary: str = array_summary(  # type: ignore[call-overload]
        tensor,
        as_list=False,
        **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS,
    )
    return summary
# defaults consumed by `dict_info` (see `DBGDictDefaultsType`)
DBG_DICT_DEFAULTS: DBGDictDefaultsType = {
    "key_types": True,
    "val_types": True,
    "max_len": 32,
    "indent": "  ",
    "max_depth": 3,
}

# defaults consumed by `list_info` (see `DBGListDefaultsType`)
DBG_LIST_DEFAULTS: DBGListDefaultsType = {
    "max_len": 16,
    "summary_show_types": True,
}
def list_info(
    lst: typing.List[typing.Any],
) -> str:
    """Render a list: full repr when short, a length/type summary when long.

    Lists longer than `DBG_LIST_DEFAULTS["max_len"]` collapse to a
    `<list of len()=N, types={...}>` summary (types shown only when
    `summary_show_types` is enabled); shorter lists get an element-wise repr.
    """
    n_items: int = len(lst)
    if n_items <= DBG_LIST_DEFAULTS["max_len"]:
        return "[" + ", ".join(repr(item) for item in lst) + "]"

    summary: str = f"<list of len()={n_items}"
    if DBG_LIST_DEFAULTS["summary_show_types"]:
        type_names: typing.Set[str] = {type(item).__name__ for item in lst}
        summary += f", types={{{', '.join(sorted(type_names))}}}"
    return summary + ">"
# string forms of tensor/array types: comparing `str(type(v))` against these
# detects torch/numpy values without importing either library
TENSOR_STR_TYPES: typing.Set[str] = {
    "<class 'torch.Tensor'>",
    "<class 'numpy.ndarray'>",
}
def dict_info(
    d: typing.Dict[typing.Any, typing.Any],
    depth: int = 0,
) -> str:
    """Recursively summarize a dict: a header line plus indented entries.

    The header shows the dict length and, per `DBG_DICT_DEFAULTS`, the sets
    of key/value type names. Entries are expanded only while `depth` is below
    `max_depth` and the dict is non-empty and shorter than `max_len`; nested
    dicts recurse one level deeper, tensors/arrays and lists use their
    dedicated formatters, everything else falls back to `repr`.
    """
    size: int = len(d)
    pad: str = DBG_DICT_DEFAULTS["indent"]

    # header line, e.g. `<dict of len()=3, key_types={str}, val_types={int}>`
    header_parts: typing.List[str] = [f"{pad * depth}<dict of len()={size}"]
    if size > 0:
        if DBG_DICT_DEFAULTS["key_types"]:
            key_names = sorted({type(k).__name__ for k in d.keys()})
            header_parts.append(f", key_types={{{', '.join(key_names)}}}")
        if DBG_DICT_DEFAULTS["val_types"]:
            val_names = sorted({type(v).__name__ for v in d.values()})
            header_parts.append(f", val_types={{{', '.join(val_names)}}}")
    header_parts.append(">")
    result: str = "".join(header_parts)

    # expand individual entries only when not too deep and not too many
    should_expand: bool = (
        depth < DBG_DICT_DEFAULTS["max_depth"]
        and 0 < size < DBG_DICT_DEFAULTS["max_len"]
    )
    if should_expand:
        entry_pad: str = pad * (depth + 1)
        for key, val in d.items():
            key_repr: str = key if isinstance(key, str) else repr(key)
            val_repr: str
            if isinstance(val, dict):
                val_repr = dict_info(val, depth + 1)
            elif str(type(val)) in TENSOR_STR_TYPES:
                val_repr = tensor_info(val)
            elif isinstance(val, list):
                val_repr = list_info(val)
            else:
                val_repr = repr(val)
            result += f"\n{entry_pad}{key_repr}{DBG_TENSOR_VAL_JOINER}{val_repr}"

    return result
def info_auto(
    obj: typing.Any,
) -> str:
    """Automatically format an object for debugging.

    Dispatches on runtime type: dicts -> `dict_info`, lists -> `list_info`,
    torch/numpy values -> `tensor_info`, anything else -> `repr`.
    """
    if isinstance(obj, dict):
        return dict_info(obj)
    if isinstance(obj, list):
        return list_info(obj)
    if str(type(obj)) in TENSOR_STR_TYPES:
        return tensor_info(obj)
    return repr(obj)
def dbg_tensor(
    tensor: _ExpType,  # numpy array or torch tensor
) -> _ExpType:
    """Debug-print a tensor with `tensor_info`; returns the tensor unchanged."""
    return dbg(tensor, formatter=tensor_info, val_joiner=DBG_TENSOR_VAL_JOINER)
def dbg_dict(
    d: _ExpType_dict,
) -> _ExpType_dict:
    """Debug-print a dictionary with `dict_info`; returns the dict unchanged."""
    return dbg(d, formatter=dict_info, val_joiner=DBG_TENSOR_VAL_JOINER)
def dbg_auto(
    obj: _ExpType,
) -> _ExpType:
    """Debug-print any object, picking a formatter via `info_auto`; returns it unchanged."""
    return dbg(obj, formatter=info_auto, val_joiner=DBG_TENSOR_VAL_JOINER)
376def _normalize_for_loose(text: str) -> str:
377 """Normalize text for loose matching by replacing non-alphanumeric chars with spaces."""
378 normalized: str = re.sub(r"[^a-zA-Z0-9]+", " ", text)
379 return " ".join(normalized.split())
382def _compile_pattern(
383 pattern: str | re.Pattern[str],
384 *,
385 cased: bool = False,
386 loose: bool = False,
387) -> re.Pattern[str]:
388 """Compile pattern with appropriate flags for case sensitivity and loose matching."""
389 if isinstance(pattern, re.Pattern):
390 return pattern
392 # Start with no flags for case-insensitive default
393 flags: int = 0
394 if not cased:
395 flags |= re.IGNORECASE
397 if loose:
398 pattern = _normalize_for_loose(pattern)
400 return re.compile(pattern, flags)
def grep_repr(
    obj: typing.Any,
    pattern: str | re.Pattern[str],
    *,
    char_context: int | None = 20,
    line_context: int | None = None,
    before_context: int = 0,
    after_context: int = 0,
    context: int | None = None,
    max_count: int | None = None,
    cased: bool = False,
    loose: bool = False,
    line_numbers: bool = False,
    highlight: bool = True,
    color: str = "31",
    separator: str = "--",
    quiet: bool = False,
) -> typing.List[str] | None:
    """grep-like search on ``repr(obj)`` with improved grep-style options.

    By default, string patterns are case-insensitive. Pre-compiled regex
    patterns use their own flags.

    Parameters:
    - obj: Object to search (its repr() string is scanned)
    - pattern: Regular expression pattern (string or pre-compiled)
    - char_context: Characters of context before/after each match (default: 20)
    - line_context: Lines of context before/after; overrides char_context
    - before_context: Lines of context before match (like grep -B)
    - after_context: Lines of context after match (like grep -A)
    - context: Lines of context before AND after (like grep -C)
    - max_count: Stop after this many matches
    - cased: Force case-sensitive search for string patterns
    - loose: Normalize spaces/punctuation for flexible matching
    - line_numbers: Show line numbers in output
    - highlight: Wrap matches with ANSI color codes
    - color: ANSI color code (default: "31" for red)
    - separator: Separator between multiple matches
    - quiet: Return results instead of printing

    Returns:
    - None if quiet=False (prints to stdout)
    - List[str] if quiet=True (returns formatted output lines)
    """
    # Handle context parameter shortcuts (-C sets both -B and -A)
    if context is not None:
        before_context = after_context = context

    # Prepare text and pattern; loose mode normalizes BOTH the scanned text
    # and (inside _compile_pattern) the pattern, so they match the same form
    text: str = repr(obj)
    if loose:
        text = _normalize_for_loose(text)

    regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose)

    def _color_match(segment: str) -> str:
        # re-run the regex over the output segment, wrapping each match in
        # bold + the configured ANSI color, reset afterwards
        if not highlight:
            return segment
        return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment)

    output_lines: list[str] = []
    match_count: int = 0

    # Determine if we're using line-based context
    using_line_context = (
        line_context is not None or before_context > 0 or after_context > 0
    )

    if using_line_context:
        lines: list[str] = text.splitlines()
        # precompute each line's starting character offset so a match position
        # can be mapped back to a line index
        line_starts: list[int] = []
        pos: int = 0
        for line in lines:
            line_starts.append(pos)
            pos += len(line) + 1  # +1 for newline

        processed_lines: set[int] = set()

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # Find which line contains this match: the last line whose start
            # offset is at or before the match start (linear scan per match)
            match_line = max(
                i for i, start in enumerate(line_starts) if start <= match.start()
            )

            # Calculate context range
            ctx_before: int
            ctx_after: int
            if line_context is not None:
                ctx_before = ctx_after = line_context
            else:
                ctx_before, ctx_after = before_context, after_context

            start_line: int = max(0, match_line - ctx_before)
            end_line: int = min(len(lines), match_line + ctx_after + 1)

            # Avoid duplicate output for overlapping contexts
            # NOTE(review): a match whose context overlaps an already-emitted
            # block is skipped entirely and does NOT count toward max_count
            line_range: set[int] = set(range(start_line, end_line))
            if line_range & processed_lines:
                continue
            processed_lines.update(line_range)

            # Format the context block
            context_lines: list[str] = []
            for i in range(start_line, end_line):
                line_text = lines[i]
                if line_numbers:
                    # NOTE(review): the `N:` prefix is added before highlighting,
                    # so the pattern can also match inside the prefix digits
                    line_prefix = f"{i + 1}:"
                    line_text = f"{line_prefix}{line_text}"
                context_lines.append(_color_match(line_text))

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.extend(context_lines)
            match_count += 1

    else:
        # Character-based context: emit a snippet of +/- ctx characters
        # around each match, clamped to the text bounds
        ctx: int = 0 if char_context is None else char_context

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            start: int = max(0, match.start() - ctx)
            end: int = min(len(text), match.end() + ctx)
            snippet: str = text[start:end]

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.append(_color_match(snippet))
            match_count += 1

    if quiet:
        return output_lines
    else:
        for line in output_lines:
            print(line)
        return None