Coverage for muutils/dbg.py: 69%

203 statements  

coverage.py v7.11.0, created at 2025-10-28 17:24 +0000

1""" 

2 

3this code is based on a Python implementation of Rust's builtin `dbg!` macro, originally from 

4https://github.com/tylerwince/pydbg/blob/master/pydbg.py 

5although it has been significantly modified 

6 

7licensed under MIT: 

8 

9Copyright (c) 2019 Tyler Wince 

10 

11Permission is hereby granted, free of charge, to any person obtaining a copy 

12of this software and associated documentation files (the "Software"), to deal 

13in the Software without restriction, including without limitation the rights 

14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

15copies of the Software, and to permit persons to whom the Software is 

16furnished to do so, subject to the following conditions: 

17 

18The above copyright notice and this permission notice shall be included in 

19all copies or substantial portions of the Software. 

20 

21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

27THE SOFTWARE. 

28 

29""" 

30 

31from __future__ import annotations 

32 

33import inspect 

34import sys 

35import typing 

36from pathlib import Path 

37import re 

38 

39# type defs 

40_ExpType = typing.TypeVar("_ExpType") 

41_ExpType_dict = typing.TypeVar( 

42 "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any] 

43) 

44_ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any]) 

45 

46 

47# Sentinel type for no expression passed 

48class _NoExpPassedSentinel: 

49 """Unique sentinel type used to indicate that no expression was passed.""" 

50 

51 pass 

52 

53 

54_NoExpPassed = _NoExpPassedSentinel() 

55 
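A sentinel instance, rather than `None`, marks the "no argument" case so that `dbg(None)` still reports the value `None`. A minimal sketch of the difference (file name and line numbers are illustrative):

    from muutils.dbg import dbg

    dbg()      # no expression: prints only location and a counter, e.g. "[ script.py:3 ] <dbg 0>"
    dbg(None)  # a real value: prints roughly "[ script.py:4 ] None = None"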

56# global variables 

57_CWD: Path = Path.cwd().absolute() 

58_COUNTER: int = 0 

59 

60# configuration 

61PATH_MODE: typing.Literal["relative", "absolute"] = "relative" 

62DEFAULT_VAL_JOINER: str = " = " 

63 

64 

65# path processing 

66def _process_path(path: Path) -> str: 

67 path_abs: Path = path.absolute() 

68 fname: Path 

69 if PATH_MODE == "absolute": 

70 fname = path_abs 

71 elif PATH_MODE == "relative": 

72 try: 

73 # if it's inside the cwd, print the relative path 

74 fname = path.relative_to(_CWD) 

75 except ValueError: 

76 # if it's not inside the cwd, use the absolute path 

77 fname = path_abs 

78 else: 

79 raise ValueError("PATH_MODE must be either 'relative' or 'absolute'") 

80 

81 return fname.as_posix() 

82 
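A brief sketch of how the `PATH_MODE` setting changes what `_process_path` returns; the paths are illustrative and assume a POSIX filesystem:

    from pathlib import Path

    import muutils.dbg as dbg_mod

    # default "relative": paths under the current working directory are shown relative to it
    print(dbg_mod._process_path(Path.cwd() / "pkg" / "mod.py"))  # -> pkg/mod.py

    # switch to absolute paths globally
    dbg_mod.PATH_MODE = "absolute"
    print(dbg_mod._process_path(Path("/tmp/example.py")))  # -> /tmp/example.py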

83 

84# actual dbg function 

85@typing.overload 

86def dbg() -> _NoExpPassedSentinel: ... 

87@typing.overload 

88def dbg( 

89 exp: _NoExpPassedSentinel, 

90 formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None, 

91 val_joiner: str = DEFAULT_VAL_JOINER, 

92) -> _NoExpPassedSentinel: ... 

93@typing.overload 

94def dbg( 

95 exp: _ExpType, 

96 formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None, 

97 val_joiner: str = DEFAULT_VAL_JOINER, 

98) -> _ExpType: ... 

99def dbg( 

100 exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed, 

101 formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None, 

102 val_joiner: str = DEFAULT_VAL_JOINER, 

103) -> typing.Union[_ExpType, _NoExpPassedSentinel]: 

104 """Call dbg with any variable or expression. 

105 

106 Calling dbg will print to stderr the current filename and lineno, 

107 as well as the passed expression and what the expression evaluates to: 

108 

109 from muutils.dbg import dbg 

110 

111 a = 2 

112 b = 5 

113 

114 dbg(a+b) 

115 

116 def square(x: int) -> int: 

117 return x * x 

118 

119 dbg(square(a)) 

120 

121 """ 

122 global _COUNTER 

123 

124 # get the context 

125 line_exp: str = "unknown" 

126 current_file: str = "unknown" 

127 dbg_frame: typing.Optional[inspect.FrameInfo] = None 

128 for frame in inspect.stack(): 

129 if frame.code_context is None: 

130 continue 

131 line: str = frame.code_context[0] 

132 if "dbg" in line: 

133 current_file = _process_path(Path(frame.filename)) 

134 dbg_frame = frame 

135 start: int = line.find("(") + 1 

136 end: int = line.rfind(")") 

137 if end == -1: 

138 end = len(line) 

139 line_exp = line[start:end] 

140 break 

141 

142 fname: str = "unknown" 

143 if current_file.startswith("/tmp/ipykernel_"): 

144 stack: list[inspect.FrameInfo] = inspect.stack() 

145 filtered_functions: list[str] = [] 

146 # this loop will find, in this order: 

147 # - the dbg function call 

148 # - the functions we care about displaying 

149 # - `<module>` 

150 # - a bunch of jupyter internals we don't care about 

151 for frame_info in stack: 

152 if _process_path(Path(frame_info.filename)) != current_file: 

153 continue 

154 if frame_info.function == "<module>": 

155 break 

156 if frame_info.function.startswith("dbg"): 

157 continue 

158 filtered_functions.append(frame_info.function) 

159 if dbg_frame is not None: 

160 filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}") 

161 else: 

162 filtered_functions.append(current_file) 

163 filtered_functions.reverse() 

164 fname = " -> ".join(filtered_functions) 

165 elif dbg_frame is not None: 

166 fname = f"{current_file}:{dbg_frame.lineno}" 

167 

168 # assemble the message 

169 msg: str 

170 if exp is _NoExpPassed: 

171 # if no expression is passed, just show location and counter value 

172 msg = f"[ {fname} ] <dbg {_COUNTER}>" 

173 _COUNTER += 1 

174 else: 

175 # if expression passed, format its value and show location, expr, and value 

176 exp_val: str = formatter(exp) if formatter else repr(exp) 

177 msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}" 

178 

179 # print the message 

180 print( 

181 msg, 

182 file=sys.stderr, 

183 ) 

184 

185 # return the expression itself 

186 return exp 

187 
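A short usage sketch of the message format assembled above; the file name and line number are illustrative:

    from muutils.dbg import dbg

    a, b = 2, 5
    total = dbg(a + b)  # prints to stderr, roughly: [ script.py:4 ] a + b = 7
    assert total == 7   # the expression's value is returned unchanged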

188 

189# formatted `dbg_*` functions with their helpers 

190 

191DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[ 

192 str, typing.Union[None, bool, int, str] 

193] = dict( 

194 fmt="unicode", 

195 precision=2, 

196 stats=True, 

197 shape=True, 

198 dtype=True, 

199 device=True, 

200 requires_grad=True, 

201 sparkline=True, 

202 sparkline_bins=7, 

203 sparkline_logy=None, # None means auto-detect 

204 colored=True, 

205 eq_char="=", 

206) 

207 
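These defaults are plain module-level dictionaries, so they can be adjusted before calling the `dbg_*` helpers. A sketch, assuming `numpy` is installed (the `array_summary` formatter is only imported lazily inside `tensor_info`):

    import numpy as np

    from muutils import dbg as dbg_mod

    dbg_mod.DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS["precision"] = 4
    dbg_mod.DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS["sparkline"] = False

    x = np.linspace(0.0, 1.0, 100)
    dbg_mod.dbg_tensor(x)  # one summary line printed to stderr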

208 

209DBG_TENSOR_VAL_JOINER: str = ": " 

210 

211 

212def tensor_info(tensor: typing.Any) -> str: 

213 from muutils.tensor_info import array_summary 

214 

215 return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS) 

216 

217 

218DBG_DICT_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = dict( 

219 key_types=True, 

220 val_types=True, 

221 max_len=32, 

222 indent=" ", 

223 max_depth=3, 

224) 

225 

226DBG_LIST_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = dict( 

227 max_len=16, 

228 summary_show_types=True, 

229) 

230 

231 

232def list_info( 

233 lst: typing.List[typing.Any], 

234) -> str: 

235 len_l: int = len(lst) 

236 output: str 

237 # TYPING: make `DBG_LIST_DEFAULTS` and the others typed dicts 

238 if len_l > DBG_LIST_DEFAULTS["max_len"]: # type: ignore[operator] 

239 output = f"<list of len()={len_l}" 

240 if DBG_LIST_DEFAULTS["summary_show_types"]: 

241 val_types: typing.Set[str] = set(type(x).__name__ for x in lst) 

242 output += f", types={{{', '.join(sorted(val_types))}}}" 

243 output += ">" 

244 else: 

245 output = "[" + ", ".join(repr(x) for x in lst) + "]" 

246 

247 return output 

248 
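A sketch of the two branches: lists at or under `max_len` are shown verbatim, longer ones collapse to a typed summary:

    from muutils.dbg import list_info

    print(list_info([1, 2, 3]))        # -> [1, 2, 3]
    print(list_info(list(range(20))))  # -> <list of len()=20, types={int}>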

249 

250TENSOR_STR_TYPES: typing.Set[str] = { 

251 "<class 'torch.Tensor'>", 

252 "<class 'numpy.ndarray'>", 

253} 

254 

255 

256def dict_info( 

257 d: typing.Dict[typing.Any, typing.Any], 

258 depth: int = 0, 

259) -> str: 

260 len_d: int = len(d) 

261 indent: str = DBG_DICT_DEFAULTS["indent"] # type: ignore[assignment] 

262 

263 # summary line 

264 output: str = f"{indent * depth}<dict of len()={len_d}" 

265 

266 if DBG_DICT_DEFAULTS["key_types"] and len_d > 0: 

267 key_types: typing.Set[str] = set(type(k).__name__ for k in d.keys()) 

268 key_types_str: str = "{" + ", ".join(sorted(key_types)) + "}" 

269 output += f", key_types={key_types_str}" 

270 

271 if DBG_DICT_DEFAULTS["val_types"] and len_d > 0: 

272 val_types: typing.Set[str] = set(type(v).__name__ for v in d.values()) 

273 val_types_str: str = "{" + ", ".join(sorted(val_types)) + "}" 

274 output += f", val_types={val_types_str}" 

275 

276 output += ">" 

277 

278 # keys/values if not too deep and not too many 

279 if depth < DBG_DICT_DEFAULTS["max_depth"]: # type: ignore[operator] 

280 if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]: # type: ignore[operator] 

281 for k, v in d.items(): 

282 key_str: str = repr(k) if not isinstance(k, str) else k 

283 

284 val_str: str 

285 val_type_str: str = str(type(v)) 

286 if isinstance(v, dict): 

287 val_str = dict_info(v, depth + 1) 

288 elif val_type_str in TENSOR_STR_TYPES: 

289 val_str = tensor_info(v) 

290 elif isinstance(v, list): 

291 val_str = list_info(v) 

292 else: 

293 val_str = repr(v) 

294 

295 output += ( 

296 f"\n{indent * (depth + 1)}{key_str}{DBG_TENSOR_VAL_JOINER}{val_str}" 

297 ) 

298 

299 return output 

300 
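A sketch of the summary-plus-items output for a small dictionary:

    from muutils.dbg import dict_info

    d = {"name": "test", "values": [1, 2, 3]}
    print(dict_info(d))
    # <dict of len()=2, key_types={str}, val_types={list, str}>
    #   name: 'test'
    #   values: [1, 2, 3]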

301 

302def info_auto( 

303 obj: typing.Any, 

304) -> str: 

305 """Automatically format an object for debugging.""" 

306 if isinstance(obj, dict): 

307 return dict_info(obj) 

308 elif isinstance(obj, list): 

309 return list_info(obj) 

310 elif str(type(obj)) in TENSOR_STR_TYPES: 

311 return tensor_info(obj) 

312 else: 

313 return repr(obj) 

314 

315 

316def dbg_tensor( 

317 tensor: _ExpType, # numpy array or torch tensor 

318) -> _ExpType: 

319 """dbg function for tensors, using tensor_info formatter.""" 

320 return dbg( 

321 tensor, 

322 formatter=tensor_info, 

323 val_joiner=DBG_TENSOR_VAL_JOINER, 

324 ) 

325 

326 

327def dbg_dict( 

328 d: _ExpType_dict, 

329) -> _ExpType_dict: 

330 """dbg function for dictionaries, using dict_info formatter.""" 

331 return dbg( 

332 d, 

333 formatter=dict_info, 

334 val_joiner=DBG_TENSOR_VAL_JOINER, 

335 ) 

336 

337 

338def dbg_auto( 

339 obj: _ExpType, 

340) -> _ExpType: 

341 """dbg function for automatic formatting based on type.""" 

342 return dbg( 

343 obj, 

344 formatter=info_auto, 

345 val_joiner=DBG_TENSOR_VAL_JOINER, 

346 ) 

347 
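The three wrappers differ only in which formatter they hand to `dbg`; `dbg_auto` picks one per value, so it is the usual entry point. A small sketch:

    from muutils.dbg import dbg_auto, dbg_dict

    config = {"lr": 1e-3, "layers": [64, 64, 32]}
    dbg_dict(config)   # dict summarized via dict_info
    dbg_auto(config)   # same here, since dbg_auto dispatches dicts to dict_info
    dbg_auto("plain")  # falls back to repr() for other types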

348 

349def _normalize_for_loose(text: str) -> str: 

350 """Normalize text for loose matching by replacing non-alphanumeric chars with spaces.""" 

351 normalized: str = re.sub(r"[^a-zA-Z0-9]+", " ", text) 

352 return " ".join(normalized.split()) 

353 

354 

355def _compile_pattern( 

356 pattern: str | re.Pattern[str], 

357 *, 

358 cased: bool = False, 

359 loose: bool = False, 

360) -> re.Pattern[str]: 

361 """Compile pattern with appropriate flags for case sensitivity and loose matching.""" 

362 if isinstance(pattern, re.Pattern): 

363 return pattern 

364 

365 # string patterns are case-insensitive by default; add IGNORECASE unless cased=True 

366 flags: int = 0 

367 if not cased: 

368 flags |= re.IGNORECASE 

369 

370 if loose: 

371 pattern = _normalize_for_loose(pattern) 

372 

373 return re.compile(pattern, flags) 

374 
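A sketch of what `loose` matching does: the pattern (and, in `grep_repr`, the searched text) is normalized so punctuation and spacing differences stop mattering; the strings are illustrative:

    from muutils.dbg import _compile_pattern, _normalize_for_loose

    print(_normalize_for_loose("hello,   world!"))  # -> hello world

    rx = _compile_pattern("hello, world", loose=True)  # pattern becomes "hello world", IGNORECASE set
    print(bool(rx.search(_normalize_for_loose("Hello... WORLD?"))))  # -> True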

375 

376def grep_repr( 

377 obj: typing.Any, 

378 pattern: str | re.Pattern[str], 

379 *, 

380 char_context: int | None = 20, 

381 line_context: int | None = None, 

382 before_context: int = 0, 

383 after_context: int = 0, 

384 context: int | None = None, 

385 max_count: int | None = None, 

386 cased: bool = False, 

387 loose: bool = False, 

388 line_numbers: bool = False, 

389 highlight: bool = True, 

390 color: str = "31", 

391 separator: str = "--", 

392 quiet: bool = False, 

393) -> typing.List[str] | None: 

394 """grep-like search on ``repr(obj)`` with improved grep-style options. 

395 

396 By default, string patterns are case-insensitive. Pre-compiled regex 

397 patterns use their own flags. 

398 

399 Parameters: 

400 - obj: Object to search (its repr() string is scanned) 

401 - pattern: Regular expression pattern (string or pre-compiled) 

402 - char_context: Characters of context before/after each match (default: 20) 

403 - line_context: Lines of context before/after; overrides char_context 

404 - before_context: Lines of context before match (like grep -B) 

405 - after_context: Lines of context after match (like grep -A) 

406 - context: Lines of context before AND after (like grep -C) 

407 - max_count: Stop after this many matches 

408 - cased: Force case-sensitive search for string patterns 

409 - loose: Normalize spaces/punctuation for flexible matching 

410 - line_numbers: Show line numbers in output 

411 - highlight: Wrap matches with ANSI color codes 

412 - color: ANSI color code (default: "31" for red) 

413 - separator: Separator between multiple matches 

414 - quiet: Return results instead of printing 

415 

416 Returns: 

417 - None if quiet=False (prints to stdout) 

418 - List[str] if quiet=True (returns formatted output lines) 

419 """ 

420 # Handle context parameter shortcuts 

421 if context is not None: 

422 before_context = after_context = context 

423 

424 # Prepare text and pattern 

425 text: str = repr(obj) 

426 if loose: 

427 text = _normalize_for_loose(text) 

428 

429 regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose) 

430 

431 def _color_match(segment: str) -> str: 

432 if not highlight: 

433 return segment 

434 return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment) 

435 

436 output_lines: list[str] = [] 

437 match_count: int = 0 

438 

439 # Determine if we're using line-based context 

440 using_line_context = ( 

441 line_context is not None or before_context > 0 or after_context > 0 

442 ) 

443 

444 if using_line_context: 

445 lines: list[str] = text.splitlines() 

446 line_starts: list[int] = [] 

447 pos: int = 0 

448 for line in lines: 

449 line_starts.append(pos) 

450 pos += len(line) + 1 # +1 for newline 

451 

452 processed_lines: set[int] = set() 

453 

454 for match in regex.finditer(text): 

455 if max_count is not None and match_count >= max_count: 

456 break 

457 

458 # Find which line contains this match 

459 match_line = max( 

460 i for i, start in enumerate(line_starts) if start <= match.start() 

461 ) 

462 

463 # Calculate context range 

464 ctx_before: int 

465 ctx_after: int 

466 if line_context is not None: 

467 ctx_before = ctx_after = line_context 

468 else: 

469 ctx_before, ctx_after = before_context, after_context 

470 

471 start_line: int = max(0, match_line - ctx_before) 

472 end_line: int = min(len(lines), match_line + ctx_after + 1) 

473 

474 # Avoid duplicate output for overlapping contexts 

475 line_range: set[int] = set(range(start_line, end_line)) 

476 if line_range & processed_lines: 

477 continue 

478 processed_lines.update(line_range) 

479 

480 # Format the context block 

481 context_lines: list[str] = [] 

482 for i in range(start_line, end_line): 

483 line_text = lines[i] 

484 if line_numbers: 

485 line_prefix = f"{i + 1}:" 

486 line_text = f"{line_prefix}{line_text}" 

487 context_lines.append(_color_match(line_text)) 

488 

489 if output_lines and separator: 

490 output_lines.append(separator) 

491 output_lines.extend(context_lines) 

492 match_count += 1 

493 

494 else: 

495 # Character-based context 

496 ctx: int = 0 if char_context is None else char_context 

497 

498 for match in regex.finditer(text): 

499 if max_count is not None and match_count >= max_count: 

500 break 

501 

502 start: int = max(0, match.start() - ctx) 

503 end: int = min(len(text), match.end() + ctx) 

504 snippet: str = text[start:end] 

505 

506 if output_lines and separator: 

507 output_lines.append(separator) 

508 output_lines.append(_color_match(snippet)) 

509 match_count += 1 

510 

511 if quiet: 

512 return output_lines 

513 else: 

514 for line in output_lines: 

515 print(line) 

516 return None
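A short usage sketch of `grep_repr`; the object and pattern are illustrative:

    from muutils.dbg import grep_repr

    data = {"alpha": list(range(50)), "beta": "needle in a haystack"}

    # default: ~20 characters of context around each match, highlighted and printed to stdout
    grep_repr(data, "needle")

    # grep -C style line context, returned as a list instead of printed
    lines = grep_repr(data, "needle", context=1, highlight=False, quiet=True)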