Coverage for muutils / dbg.py: 87%

228 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-18 21:32 -0600

1""" 

2 

3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from 

4https://github.com/tylerwince/pydbg/blob/master/pydbg.py 

5although it has been significantly modified 

6 

7licensed under MIT: 

8 

9Copyright (c) 2019 Tyler Wince 

10 

11Permission is hereby granted, free of charge, to any person obtaining a copy 

12of this software and associated documentation files (the "Software"), to deal 

13in the Software without restriction, including without limitation the rights 

14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

15copies of the Software, and to permit persons to whom the Software is 

16furnished to do so, subject to the following conditions: 

17 

18The above copyright notice and this permission notice shall be included in 

19all copies or substantial portions of the Software. 

20 

21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

27THE SOFTWARE. 

28 

29""" 

30 

31from __future__ import annotations 

32 

33import inspect 

34import sys 

35import typing 

36from pathlib import Path 

37import re 

38 

# type defs
# generic type var used so `dbg(exp)` returns exp with its original type preserved
_ExpType = typing.TypeVar("_ExpType")
# bounded type vars for the dict/list-specific dbg wrappers below
_ExpType_dict = typing.TypeVar(
    "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any]
)
_ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any])

45 

46 

# TypedDict definitions for configuration dictionaries
class DBGDictDefaultsType(typing.TypedDict):
    """Schema for `DBG_DICT_DEFAULTS` — controls `dict_info` formatting."""

    key_types: bool  # include the set of key type names in the summary
    val_types: bool  # include the set of value type names in the summary
    max_len: int  # dicts with at least this many items are not expanded
    indent: str  # indentation unit per recursion depth
    max_depth: int  # recursion cutoff for nested dicts

54 

55 

class DBGListDefaultsType(typing.TypedDict):
    """Schema for `DBG_LIST_DEFAULTS` — controls `list_info` formatting."""

    max_len: int  # lists longer than this are summarized instead of repr'd
    summary_show_types: bool  # include element type names in the summary

59 

60 

class DBGTensorArraySummaryDefaultsType(typing.TypedDict):
    """Schema for `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`.

    Keys are forwarded as keyword arguments to
    `muutils.tensor_info.array_summary` by `tensor_info`.
    """

    fmt: typing.Literal["unicode", "latex", "ascii"]  # output character set
    precision: int  # decimal places for numeric stats
    stats: bool  # include summary statistics
    shape: bool  # include tensor shape
    dtype: bool  # include dtype
    device: bool  # include device (torch)
    requires_grad: bool  # include requires_grad flag (torch)
    sparkline: bool  # include a histogram sparkline
    sparkline_bins: int  # number of sparkline bins
    sparkline_logy: typing.Union[None, bool]  # None means auto-detect log-y
    colored: bool  # ANSI-colored output
    eq_char: str  # char between field names and values

74 

75 

# Sentinel type for no expression passed
class _NoExpPassedSentinel:
    """Unique sentinel type used to indicate that no expression was passed."""

    pass


# singleton instance; `dbg` compares against it with `is` to detect a bare `dbg()` call
_NoExpPassed = _NoExpPassedSentinel()

84 

# global variables
# cwd captured once at import time; used by `_process_path` for relative paths
_CWD: Path = Path.cwd().absolute()
# counts bare `dbg()` calls (no expression passed)
_COUNTER: int = 0
# resolved path of this module, used to skip this module's own frames in `dbg`
_DBG_MODULE_FILE: str = str(Path(__file__).resolve())

# configuration
# how file locations are printed: relative to cwd, or absolute
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
# default separator between expression text and its value
DEFAULT_VAL_JOINER: str = " = "

93 

94 

# path processing
def _process_path(path: Path) -> str:
    """Format `path` for display according to the module-level `PATH_MODE`.

    # Parameters:
    - `path` : the path to format

    # Returns:
    POSIX-style string form of the path: relative to the current working
    directory when `PATH_MODE == "relative"` (falling back to absolute if
    the path is outside the cwd), or always absolute when
    `PATH_MODE == "absolute"`.

    # Raises:
    - `ValueError` : if `PATH_MODE` is set to anything else
    """
    path_abs: Path = path.absolute()
    fname: Path
    if PATH_MODE == "absolute":
        fname = path_abs
    elif PATH_MODE == "relative":
        try:
            # if it's inside the cwd, print the relative path
            fname = path.relative_to(_CWD)
        except ValueError:
            # if it's not under the cwd, fall back to the absolute path
            fname = path_abs
    else:
        # fixed: message previously had an unterminated quote ("...or 'absolute")
        raise ValueError("PATH_MODE must be either 'relative' or 'absolute'")

    return fname.as_posix()

112 

113 

# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    # Parameters:
    - `exp` : expression to debug; when omitted, only the location and a
        global call counter are printed
    - `formatter` : optional callable used instead of `repr` to render the value
    - `val_joiner` : string placed between the expression text and its value

    # Returns:
    the value of `exp` unchanged, so `dbg(...)` can be inlined into any expression
    """
    global _COUNTER

    # get the context: walk up the stack, skipping frames from this module,
    # until we find the source line containing the `dbg` call
    line_exp: str = "unknown"
    current_file: str = "unknown"
    dbg_frame: typing.Optional[inspect.FrameInfo] = None
    for frame in inspect.stack():
        if frame.code_context is None:
            continue
        if str(Path(frame.filename).resolve()) == _DBG_MODULE_FILE:
            continue
        line: str = frame.code_context[0]
        if "dbg" in line:
            current_file = _process_path(Path(frame.filename))
            dbg_frame = frame
            # textual extraction of the argument between the outermost parens;
            # NOTE(review): a call split across lines only captures the first line
            start: int = line.find("(") + 1
            end: int = line.rfind(")")
            if end == -1:
                end = len(line)
            line_exp = line[start:end]
            break

    fname: str = "unknown"
    if current_file.startswith("/tmp/ipykernel_"):
        # running inside a Jupyter/IPython kernel: the filename is a temp file,
        # so build a call-chain description instead of a path:line location
        stack: list[inspect.FrameInfo] = inspect.stack()
        filtered_functions: list[str] = []
        # this loop will find, in this order:
        # - the dbg function call
        # - the functions we care about displaying
        # - `<module>`
        # - a bunch of jupyter internals we don't care about
        for frame_info in stack:
            if _process_path(Path(frame_info.filename)) != current_file:
                continue
            if frame_info.function == "<module>":
                break
            if frame_info.function.startswith("dbg"):
                continue
            filtered_functions.append(frame_info.function)
        if dbg_frame is not None:
            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
        else:
            filtered_functions.append(current_file)
        # innermost frame was found first; reverse to read caller -> callee
        filtered_functions.reverse()
        fname = " -> ".join(filtered_functions)
    elif dbg_frame is not None:
        fname = f"{current_file}:{dbg_frame.lineno}"

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp

219 

220 

# formatted `dbg_*` functions with their helpers

# keyword arguments forwarded to `muutils.tensor_info.array_summary` by `tensor_info`
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: DBGTensorArraySummaryDefaultsType = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": None,  # None means auto-detect
    "colored": True,
    "eq_char": "=",
}


# joiner used by the formatted dbg wrappers (and `dict_info` key/value lines)
DBG_TENSOR_VAL_JOINER: str = ": "

240 

241 

def tensor_info(tensor: typing.Any) -> str:
    """Summarize a tensor/array via `muutils.tensor_info.array_summary`.

    Uses the module-level `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS` settings.
    """
    # deferred import keeps `muutils.dbg` cheap to import when tensors are unused
    from muutils.tensor_info import array_summary

    # TODO: explicitly pass args to avoid type: ignore (mypy can't match overloads with **TypedDict spread)
    return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)  # type: ignore[call-overload]

247 

248 

# settings consumed by `dict_info`; see `DBGDictDefaultsType` for field meanings
DBG_DICT_DEFAULTS: DBGDictDefaultsType = {
    "key_types": True,
    "val_types": True,
    "max_len": 32,
    "indent": "  ",
    "max_depth": 3,
}

# settings consumed by `list_info`; see `DBGListDefaultsType` for field meanings
DBG_LIST_DEFAULTS: DBGListDefaultsType = {
    "max_len": 16,
    "summary_show_types": True,
}

261 

262 

def list_info(
    lst: typing.List[typing.Any],
) -> str:
    """Render a list for debug output.

    Short lists (up to `DBG_LIST_DEFAULTS["max_len"]` items) are shown as a
    full `repr`-style listing; longer ones collapse to a one-line summary,
    optionally including the set of element type names.
    """
    n_items: int = len(lst)

    # small enough: show every element
    if n_items <= DBG_LIST_DEFAULTS["max_len"]:
        return "[" + ", ".join(repr(item) for item in lst) + "]"

    # too long: emit a compact summary instead
    summary: str = f"<list of len()={n_items}"
    if DBG_LIST_DEFAULTS["summary_show_types"]:
        type_names = sorted({type(item).__name__ for item in lst})
        summary += f", types={{{', '.join(type_names)}}}"
    return summary + ">"

278 

279 

# string forms of tensor types, so torch/numpy need not be imported to detect them
TENSOR_STR_TYPES: typing.Set[str] = {
    "<class 'torch.Tensor'>",
    "<class 'numpy.ndarray'>",
}

284 

285 

def dict_info(
    d: typing.Dict[typing.Any, typing.Any],
    depth: int = 0,
) -> str:
    """Format a dict as a summary line plus (recursively) its items.

    # Parameters:
    - `d` : dictionary to describe
    - `depth` : current recursion depth, used for indentation and the
        `max_depth` cutoff (defaults to `0`)

    Behavior is controlled by the module-level `DBG_DICT_DEFAULTS` settings;
    values are dispatched to `dict_info`/`tensor_info`/`list_info`/`repr`.
    """
    len_d: int = len(d)
    indent: str = DBG_DICT_DEFAULTS["indent"]

    # summary line
    output: str = f"{indent * depth}<dict of len()={len_d}"

    if DBG_DICT_DEFAULTS["key_types"] and len_d > 0:
        key_types: typing.Set[str] = set(type(k).__name__ for k in d.keys())
        key_types_str: str = "{" + ", ".join(sorted(key_types)) + "}"
        output += f", key_types={key_types_str}"

    if DBG_DICT_DEFAULTS["val_types"] and len_d > 0:
        val_types: typing.Set[str] = set(type(v).__name__ for v in d.values())
        val_types_str: str = "{" + ", ".join(sorted(val_types)) + "}"
        output += f", val_types={val_types_str}"

    output += ">"

    # keys/values if not too deep and not too many
    if depth < DBG_DICT_DEFAULTS["max_depth"]:
        if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]:
            for k, v in d.items():
                # str keys shown bare, everything else repr'd
                key_str: str = repr(k) if not isinstance(k, str) else k

                # string-compare the type so torch/numpy need not be imported
                val_str: str
                val_type_str: str = str(type(v))
                if isinstance(v, dict):
                    val_str = dict_info(v, depth + 1)
                elif val_type_str in TENSOR_STR_TYPES:
                    val_str = tensor_info(v)
                elif isinstance(v, list):
                    val_str = list_info(v)
                else:
                    val_str = repr(v)

                output += (
                    f"\n{indent * (depth + 1)}{key_str}{DBG_TENSOR_VAL_JOINER}{val_str}"
                )

    return output

330 

331 

def info_auto(
    obj: typing.Any,
) -> str:
    """Pick the right debug formatter for `obj` by inspecting its type.

    Dispatches to `dict_info` for dicts, `list_info` for lists,
    `tensor_info` for torch/numpy tensors, and plain `repr` otherwise.
    """
    # guard-clause dispatch, most specific checks first
    if isinstance(obj, dict):
        return dict_info(obj)
    if isinstance(obj, list):
        return list_info(obj)
    # tensors detected by type-string so torch/numpy need not be imported
    if str(type(obj)) in TENSOR_STR_TYPES:
        return tensor_info(obj)
    return repr(obj)

344 

345 

def dbg_tensor(
    tensor: _ExpType,  # numpy array or torch tensor
) -> _ExpType:
    """Debug-print a tensor with the `tensor_info` formatter; returns it unchanged."""
    return dbg(tensor, formatter=tensor_info, val_joiner=DBG_TENSOR_VAL_JOINER)

355 

356 

def dbg_dict(
    d: _ExpType_dict,
) -> _ExpType_dict:
    """Debug-print a dictionary with the `dict_info` formatter; returns it unchanged."""
    return dbg(d, formatter=dict_info, val_joiner=DBG_TENSOR_VAL_JOINER)

366 

367 

def dbg_auto(
    obj: _ExpType,
) -> _ExpType:
    """Debug-print any object, choosing a formatter via `info_auto`; returns it unchanged."""
    return dbg(obj, formatter=info_auto, val_joiner=DBG_TENSOR_VAL_JOINER)

377 

378 

379def _normalize_for_loose(text: str) -> str: 

380 """Normalize text for loose matching by replacing non-alphanumeric chars with spaces.""" 

381 normalized: str = re.sub(r"[^a-zA-Z0-9]+", " ", text) 

382 return " ".join(normalized.split()) 

383 

384 

385def _compile_pattern( 

386 pattern: str | re.Pattern[str], 

387 *, 

388 cased: bool = False, 

389 loose: bool = False, 

390) -> re.Pattern[str]: 

391 """Compile pattern with appropriate flags for case sensitivity and loose matching.""" 

392 if isinstance(pattern, re.Pattern): 

393 return pattern 

394 

395 # Start with no flags for case-insensitive default 

396 flags: int = 0 

397 if not cased: 

398 flags |= re.IGNORECASE 

399 

400 if loose: 

401 pattern = _normalize_for_loose(pattern) 

402 

403 return re.compile(pattern, flags) 

404 

405 

def grep_repr(
    obj: typing.Any,
    pattern: str | re.Pattern[str],
    *,
    char_context: int | None = 20,
    line_context: int | None = None,
    before_context: int = 0,
    after_context: int = 0,
    context: int | None = None,
    max_count: int | None = None,
    cased: bool = False,
    loose: bool = False,
    line_numbers: bool = False,
    highlight: bool = True,
    color: str = "31",
    separator: str = "--",
    quiet: bool = False,
) -> typing.List[str] | None:
    """grep-like search on ``repr(obj)`` with improved grep-style options.

    By default, string patterns are case-insensitive. Pre-compiled regex
    patterns use their own flags.

    Parameters:
    - obj: Object to search (its repr() string is scanned)
    - pattern: Regular expression pattern (string or pre-compiled)
    - char_context: Characters of context before/after each match (default: 20)
    - line_context: Lines of context before/after; overrides char_context
    - before_context: Lines of context before match (like grep -B)
    - after_context: Lines of context after match (like grep -A)
    - context: Lines of context before AND after (like grep -C)
    - max_count: Stop after this many matches
    - cased: Force case-sensitive search for string patterns
    - loose: Normalize spaces/punctuation for flexible matching
    - line_numbers: Show line numbers in output
    - highlight: Wrap matches with ANSI color codes
    - color: ANSI color code (default: "31" for red)
    - separator: Separator between multiple matches
    - quiet: Return results instead of printing

    Returns:
    - None if quiet=False (prints to stdout)
    - List[str] if quiet=True (returns formatted output lines)
    """
    # Handle context parameter shortcuts
    if context is not None:
        before_context = after_context = context

    # Prepare text and pattern
    text: str = repr(obj)
    if loose:
        # normalize the haystack the same way _compile_pattern normalizes the needle
        text = _normalize_for_loose(text)

    regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose)

    def _color_match(segment: str) -> str:
        # re-run the regex over the snippet and wrap each hit in ANSI codes
        if not highlight:
            return segment
        return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment)

    output_lines: list[str] = []
    match_count: int = 0

    # Determine if we're using line-based context
    using_line_context = (
        line_context is not None or before_context > 0 or after_context > 0
    )

    if using_line_context:
        lines: list[str] = text.splitlines()
        # character offset at which each line starts, to map match positions
        # back to line indices
        line_starts: list[int] = []
        pos: int = 0
        for line in lines:
            line_starts.append(pos)
            pos += len(line) + 1  # +1 for newline

        processed_lines: set[int] = set()

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # Find which line contains this match
            # NOTE(review): linear scan per match; fine for repr-sized text
            match_line = max(
                i for i, start in enumerate(line_starts) if start <= match.start()
            )

            # Calculate context range
            ctx_before: int
            ctx_after: int
            if line_context is not None:
                ctx_before = ctx_after = line_context
            else:
                ctx_before, ctx_after = before_context, after_context

            start_line: int = max(0, match_line - ctx_before)
            end_line: int = min(len(lines), match_line + ctx_after + 1)

            # Avoid duplicate output for overlapping contexts
            # (a match whose context overlaps an already-printed block is
            # skipped entirely and does not count toward max_count)
            line_range: set[int] = set(range(start_line, end_line))
            if line_range & processed_lines:
                continue
            processed_lines.update(line_range)

            # Format the context block
            context_lines: list[str] = []
            for i in range(start_line, end_line):
                line_text = lines[i]
                if line_numbers:
                    # NOTE(review): the prefix is prepended before coloring, so
                    # the regex may also highlight digits inside the prefix
                    line_prefix = f"{i + 1}:"
                    line_text = f"{line_prefix}{line_text}"
                context_lines.append(_color_match(line_text))

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.extend(context_lines)
            match_count += 1

    else:
        # Character-based context
        ctx: int = 0 if char_context is None else char_context

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # clamp the snippet window to the text bounds
            start: int = max(0, match.start() - ctx)
            end: int = min(len(text), match.end() + ctx)
            snippet: str = text[start:end]

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.append(_color_match(snippet))
            match_count += 1

    if quiet:
        return output_lines
    else:
        for line in output_lines:
            print(line)
        return None