Coverage for muutils / dbg.py: 87%

225 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:25 -0700

1""" 

2 

3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from 

4https://github.com/tylerwince/pydbg/blob/master/pydbg.py 

5although it has been significantly modified 

6 

7licensed under MIT: 

8 

9Copyright (c) 2019 Tyler Wince 

10 

11Permission is hereby granted, free of charge, to any person obtaining a copy 

12of this software and associated documentation files (the "Software"), to deal 

13in the Software without restriction, including without limitation the rights 

14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

15copies of the Software, and to permit persons to whom the Software is 

16furnished to do so, subject to the following conditions: 

17 

18The above copyright notice and this permission notice shall be included in 

19all copies or substantial portions of the Software. 

20 

21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

27THE SOFTWARE. 

28 

29""" 

30 

31from __future__ import annotations 

32 

33import inspect 

34import sys 

35import typing 

36from pathlib import Path 

37import re 

38 

# type defs
# generic passthrough TypeVar: dbg(exp) returns exp with its static type intact
_ExpType = typing.TypeVar("_ExpType")
# dict-bound variant, used by dbg_dict so the input dict type is preserved
_ExpType_dict = typing.TypeVar(
    "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any]
)
# list-bound variant (currently unused by the dbg_* wrappers below)
_ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any])

45 

46 

47# TypedDict definitions for configuration dictionaries 

class DBGDictDefaultsType(typing.TypedDict):
    """Schema for `DBG_DICT_DEFAULTS` — options that control `dict_info` output."""

    key_types: bool  # include a key-type summary in the header line
    val_types: bool  # include a value-type summary in the header line
    max_len: int  # only expand individual entries for dicts smaller than this
    indent: str  # indentation unit, repeated once per nesting level
    max_depth: int  # stop recursing into nested dicts at this depth

54 

55 

class DBGListDefaultsType(typing.TypedDict):
    """Schema for `DBG_LIST_DEFAULTS` — options that control `list_info` output."""

    max_len: int  # lists longer than this are summarized instead of shown in full
    summary_show_types: bool  # include element type names in the summary form

59 

60 

class DBGTensorArraySummaryDefaultsType(typing.TypedDict):
    """Schema for `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`.

    These are keyword arguments forwarded verbatim to
    `muutils.tensor_info.array_summary` by `tensor_info`; see that function
    for the exact semantics of each option.
    """

    fmt: typing.Literal["unicode", "latex", "ascii"]
    precision: int
    stats: bool
    shape: bool
    dtype: bool
    device: bool
    requires_grad: bool
    sparkline: bool
    sparkline_bins: int
    sparkline_logy: typing.Union[None, bool]  # None means auto-detect (see defaults)
    colored: bool
    eq_char: str

74 

75 

76# Sentinel type for no expression passed 

class _NoExpPassedSentinel:
    """Unique sentinel type used to indicate that no expression was passed.

    Instances carry no state; `dbg` checks identity (`is`) against the
    module-level `_NoExpPassed` singleton below.
    """

    # NOTE: the redundant `pass` after the docstring was removed — the
    # docstring alone is a sufficient class body.


# singleton sentinel instance used as the default for dbg()'s `exp` parameter
_NoExpPassed: _NoExpPassedSentinel = _NoExpPassedSentinel()

84 

85# global variables 

# global variables
_CWD: Path = Path.cwd().absolute()  # cached cwd, used for relative path display
_COUNTER: int = 0  # incremented each time dbg() is called with no expression

# configuration
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
DEFAULT_VAL_JOINER: str = " = "


# path processing
def _process_path(path: Path) -> str:
    """Render *path* as a posix string according to the global `PATH_MODE`.

    In "relative" mode, paths under the current working directory are shown
    relative to it; anything outside the cwd falls back to the absolute path.
    In "absolute" mode, the absolute path is always used.

    Raises `ValueError` if `PATH_MODE` holds an unrecognized value.
    """
    path_abs: Path = path.absolute()
    fname: Path
    if PATH_MODE == "absolute":
        fname = path_abs
    elif PATH_MODE == "relative":
        try:
            # if it's inside the cwd, print the relative path
            # (fixed: compare the absolute form, so that a *relative* input
            # path inside the cwd is also rendered relative instead of
            # always falling through to the absolute branch)
            fname = path_abs.relative_to(_CWD)
        except ValueError:
            # not under the cwd: use the absolute path
            fname = path_abs
    else:
        # fixed: the original message had an unterminated quote on 'absolute'
        raise ValueError("PATH_MODE must be either 'relative' or 'absolute'")

    return fname.as_posix()

111 

112 

113# actual dbg function 

@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    # Parameters:
    - `exp`: the expression/value to print and return; when omitted, only the
      location and a global call counter are printed
    - `formatter`: callable used to render the value (defaults to `repr`)
    - `val_joiner`: string placed between the expression text and its value

    # Returns:
    the passed expression unchanged, so `dbg(...)` can be inserted inline.
    """
    global _COUNTER

    # get the context: walk outward through the call stack until we find the
    # source line that contains the `dbg` call, and recover the argument text
    # from between its parentheses
    line_exp: str = "unknown"
    current_file: str = "unknown"
    dbg_frame: typing.Optional[inspect.FrameInfo] = None
    for frame in inspect.stack():
        if frame.code_context is None:
            continue
        line: str = frame.code_context[0]
        if "dbg" in line:
            current_file = _process_path(Path(frame.filename))
            dbg_frame = frame
            # slice out the text between the first "(" and the last ")"
            start: int = line.find("(") + 1
            end: int = line.rfind(")")
            if end == -1:
                # call spans multiple lines; take everything after "("
                end = len(line)
            line_exp = line[start:end]
            break

    fname: str = "unknown"
    # jupyter kernels report synthetic filenames like /tmp/ipykernel_*/...;
    # show a call chain of user functions instead of the raw temp path
    if current_file.startswith("/tmp/ipykernel_"):
        stack: list[inspect.FrameInfo] = inspect.stack()
        filtered_functions: list[str] = []
        # this loop will find, in this order:
        # - the dbg function call
        # - the functions we care about displaying
        # - `<module>`
        # - a bunch of jupyter internals we don't care about
        for frame_info in stack:
            if _process_path(Path(frame_info.filename)) != current_file:
                continue
            if frame_info.function == "<module>":
                break
            if frame_info.function.startswith("dbg"):
                continue
            filtered_functions.append(frame_info.function)
        if dbg_frame is not None:
            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
        else:
            filtered_functions.append(current_file)
        # frames were collected innermost-first; display outermost-first
        filtered_functions.reverse()
        fname = " -> ".join(filtered_functions)
    elif dbg_frame is not None:
        fname = f"{current_file}:{dbg_frame.lineno}"

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp

216 

217 

218# formatted `dbg_*` functions with their helpers 

219 

# default keyword arguments forwarded to `array_summary` by `tensor_info`
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: DBGTensorArraySummaryDefaultsType = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": None,  # None means auto-detect
    "colored": True,
    "eq_char": "=",
}


# joiner between expression text and formatted value for the dbg_* wrappers
# (despite the name, it is also used by dbg_dict and dbg_auto)
DBG_TENSOR_VAL_JOINER: str = ": "

237 

238 

def tensor_info(tensor: typing.Any) -> str:
    """Render a one-line summary of a tensor/array via `muutils.tensor_info.array_summary`.

    All formatting options come from `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`.
    The import is function-local so the tensor machinery is only loaded when needed.
    """
    from muutils.tensor_info import array_summary

    # TODO: explicitly pass args to avoid type: ignore (mypy can't match overloads with **TypedDict spread)
    return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)  # type: ignore[call-overload]

244 

245 

# options controlling `dict_info`; see DBGDictDefaultsType for field meanings
DBG_DICT_DEFAULTS: DBGDictDefaultsType = {
    "key_types": True,  # show key-type summary in the header
    "val_types": True,  # show value-type summary in the header
    "max_len": 32,  # only expand entries for dicts with fewer items than this
    "indent": "  ",  # indentation unit per nesting level
    "max_depth": 3,  # recursion limit for nested dicts
}

253 

# options controlling `list_info`; see DBGListDefaultsType for field meanings
DBG_LIST_DEFAULTS: DBGListDefaultsType = {
    "max_len": 16,
    "summary_show_types": True,
}


def list_info(
    lst: typing.List[typing.Any],
) -> str:
    """Render a list for debugging.

    Lists of at most `max_len` elements are shown in full (repr of each
    element); longer lists collapse to a `<list of len()=N, types={...}>`
    summary, with the type set included when `summary_show_types` is on.
    """
    n_items: int = len(lst)

    # short enough: show every element
    if n_items <= DBG_LIST_DEFAULTS["max_len"]:
        return "[" + ", ".join(repr(item) for item in lst) + "]"

    # otherwise emit a compact summary instead of the elements
    summary: str = f"<list of len()={n_items}"
    if DBG_LIST_DEFAULTS["summary_show_types"]:
        type_names: typing.List[str] = sorted({type(item).__name__ for item in lst})
        summary += f", types={{{', '.join(type_names)}}}"
    return summary + ">"

275 

276 

# tensors are recognized by their `str(type(x))` spelling, so that neither
# torch nor numpy has to be imported just to check a value's type
TENSOR_STR_TYPES: typing.Set[str] = {
    "<class 'torch.Tensor'>",
    "<class 'numpy.ndarray'>",
}

281 

282 

def dict_info(
    d: typing.Dict[typing.Any, typing.Any],
    depth: int = 0,
) -> str:
    """Render a dict for debugging: a header line plus (optionally) its entries.

    The header shows the size and, per `DBG_DICT_DEFAULTS`, summaries of the
    key and value types. Entries are listed one per line when the dict is
    non-empty, smaller than `max_len`, and `depth` is below `max_depth`;
    nested dicts recurse, tensor-likes and lists use their own formatters.
    """
    size: int = len(d)
    pad: str = DBG_DICT_DEFAULTS["indent"]

    # header line, e.g. `<dict of len()=2, key_types={str}, val_types={int}>`
    header: str = f"{pad * depth}<dict of len()={size}"
    if DBG_DICT_DEFAULTS["key_types"] and size > 0:
        key_names = sorted({type(key).__name__ for key in d.keys()})
        header += ", key_types={" + ", ".join(key_names) + "}"
    if DBG_DICT_DEFAULTS["val_types"] and size > 0:
        val_names = sorted({type(val).__name__ for val in d.values()})
        header += ", val_types={" + ", ".join(val_names) + "}"
    header += ">"
    lines: typing.List[str] = [header]

    # entry lines, unless too deep or too many entries
    if depth < DBG_DICT_DEFAULTS["max_depth"]:
        if 0 < size < DBG_DICT_DEFAULTS["max_len"]:
            for key, val in d.items():
                shown_key: str = key if isinstance(key, str) else repr(key)

                shown_val: str
                if isinstance(val, dict):
                    # nested dict: recurse one level deeper
                    shown_val = dict_info(val, depth + 1)
                elif str(type(val)) in TENSOR_STR_TYPES:
                    shown_val = tensor_info(val)
                elif isinstance(val, list):
                    shown_val = list_info(val)
                else:
                    shown_val = repr(val)

                lines.append(
                    f"{pad * (depth + 1)}{shown_key}{DBG_TENSOR_VAL_JOINER}{shown_val}"
                )

    return "\n".join(lines)

327 

328 

def info_auto(
    obj: typing.Any,
) -> str:
    """Automatically format an object for debugging.

    Dispatches to `dict_info`, `list_info`, or `tensor_info` based on the
    value's runtime type, falling back to `repr` for everything else.
    """
    if isinstance(obj, dict):
        return dict_info(obj)
    if isinstance(obj, list):
        return list_info(obj)
    # tensors are matched by type string so torch/numpy need not be imported
    if str(type(obj)) in TENSOR_STR_TYPES:
        return tensor_info(obj)
    return repr(obj)

341 

342 

def dbg_tensor(
    tensor: _ExpType,  # numpy array or torch tensor
) -> _ExpType:
    """dbg function for tensors, using tensor_info formatter.

    Prints a one-line tensor summary to stderr, then returns `tensor`
    unchanged so the call can be inserted inline.
    """
    return dbg(tensor, formatter=tensor_info, val_joiner=DBG_TENSOR_VAL_JOINER)

352 

353 

def dbg_dict(
    d: _ExpType_dict,
) -> _ExpType_dict:
    """dbg function for dictionaries, using dict_info formatter.

    Prints a structured dict summary to stderr, then returns `d` unchanged.
    Uses the same `": "` joiner constant as the tensor variant.
    """
    return dbg(d, formatter=dict_info, val_joiner=DBG_TENSOR_VAL_JOINER)

363 

364 

def dbg_auto(
    obj: _ExpType,
) -> _ExpType:
    """dbg function for automatic formatting based on type.

    Prints a type-appropriate summary (via `info_auto`) to stderr, then
    returns `obj` unchanged.
    """
    return dbg(obj, formatter=info_auto, val_joiner=DBG_TENSOR_VAL_JOINER)

374 

375 

def _normalize_for_loose(text: str) -> str:
    """Normalize *text* for loose matching.

    Every run of characters outside [a-zA-Z0-9] collapses to a single
    space, and leading/trailing whitespace is dropped.
    """
    tokens = re.sub(r"[^a-zA-Z0-9]+", " ", text).split()
    return " ".join(tokens)

380 

381 

def _compile_pattern(
    pattern: str | re.Pattern[str],
    *,
    cased: bool = False,
    loose: bool = False,
) -> re.Pattern[str]:
    """Compile *pattern*, honoring case-sensitivity and loose-matching options.

    Pre-compiled patterns are returned untouched (their own flags win).
    String patterns default to case-insensitive unless `cased` is set; with
    `loose`, the pattern text is normalized the same way as the search text.
    """
    if isinstance(pattern, re.Pattern):
        # already compiled — the caller's flags take precedence
        return pattern

    source: str = _normalize_for_loose(pattern) if loose else pattern
    flags: int = 0 if cased else re.IGNORECASE
    return re.compile(source, flags)

401 

402 

def grep_repr(
    obj: typing.Any,
    pattern: str | re.Pattern[str],
    *,
    char_context: int | None = 20,
    line_context: int | None = None,
    before_context: int = 0,
    after_context: int = 0,
    context: int | None = None,
    max_count: int | None = None,
    cased: bool = False,
    loose: bool = False,
    line_numbers: bool = False,
    highlight: bool = True,
    color: str = "31",
    separator: str = "--",
    quiet: bool = False,
) -> typing.List[str] | None:
    """grep-like search on ``repr(obj)`` with improved grep-style options.

    By default, string patterns are case-insensitive. Pre-compiled regex
    patterns use their own flags.

    Parameters:
    - obj: Object to search (its repr() string is scanned)
    - pattern: Regular expression pattern (string or pre-compiled)
    - char_context: Characters of context before/after each match (default: 20)
    - line_context: Lines of context before/after; overrides char_context
    - before_context: Lines of context before match (like grep -B)
    - after_context: Lines of context after match (like grep -A)
    - context: Lines of context before AND after (like grep -C)
    - max_count: Stop after this many matches
    - cased: Force case-sensitive search for string patterns
    - loose: Normalize spaces/punctuation for flexible matching
    - line_numbers: Show line numbers in output
    - highlight: Wrap matches with ANSI color codes
    - color: ANSI color code (default: "31" for red)
    - separator: Separator between multiple matches
    - quiet: Return results instead of printing

    Returns:
    - None if quiet=False (prints to stdout)
    - List[str] if quiet=True (returns formatted output lines)
    """
    # Handle context parameter shortcuts (-C implies both -B and -A)
    if context is not None:
        before_context = after_context = context

    # Prepare text and pattern
    text: str = repr(obj)
    if loose:
        # NOTE(review): loose normalization collapses newlines into spaces,
        # so line-based context on a loose search sees one long line
        text = _normalize_for_loose(text)

    regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose)

    def _color_match(segment: str) -> str:
        # re-run the regex over the segment to wrap every hit in ANSI codes
        if not highlight:
            return segment
        return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment)

    output_lines: list[str] = []
    match_count: int = 0

    # Determine if we're using line-based context
    using_line_context = (
        line_context is not None or before_context > 0 or after_context > 0
    )

    if using_line_context:
        # record each line's starting offset so a match offset can be
        # translated back into a line index
        lines: list[str] = text.splitlines()
        line_starts: list[int] = []
        pos: int = 0
        for line in lines:
            line_starts.append(pos)
            pos += len(line) + 1  # +1 for newline

        processed_lines: set[int] = set()

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # Find which line contains this match: the last line whose start
            # offset is at or before the match start
            match_line = max(
                i for i, start in enumerate(line_starts) if start <= match.start()
            )

            # Calculate context range
            ctx_before: int
            ctx_after: int
            if line_context is not None:
                ctx_before = ctx_after = line_context
            else:
                ctx_before, ctx_after = before_context, after_context

            start_line: int = max(0, match_line - ctx_before)
            end_line: int = min(len(lines), match_line + ctx_after + 1)

            # Avoid duplicate output for overlapping contexts
            # NOTE(review): a match whose context overlaps an earlier block is
            # skipped entirely and does NOT count toward max_count
            line_range: set[int] = set(range(start_line, end_line))
            if line_range & processed_lines:
                continue
            processed_lines.update(line_range)

            # Format the context block
            context_lines: list[str] = []
            for i in range(start_line, end_line):
                line_text = lines[i]
                if line_numbers:
                    # NOTE(review): the prefix is added *before* highlighting,
                    # so a digit-matching pattern can also color line numbers
                    line_prefix = f"{i + 1}:"
                    line_text = f"{line_prefix}{line_text}"
                context_lines.append(_color_match(line_text))

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.extend(context_lines)
            match_count += 1

    else:
        # Character-based context: a snippet of up to `ctx` chars on each side
        ctx: int = 0 if char_context is None else char_context

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            start: int = max(0, match.start() - ctx)
            end: int = min(len(text), match.end() + ctx)
            snippet: str = text[start:end]

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.append(_color_match(snippet))
            match_count += 1

    if quiet:
        return output_lines
    else:
        for line in output_lines:
            print(line)
        return None