Coverage for muutils/tensor_info.py: 90%

250 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-07 20:16 -0700

"""get metadata about a tensor, mostly for `muutils.dbg`

`array_info` extracts raw statistics (shape, dtype, NaN count, min/max, mean,
std, median, histogram) from a numpy array or torch tensor, and
`array_summary` formats them as a compact unicode, latex, or ascii string,
optionally colored. torch is never imported here: tensors are recognized by
their class name.
"""

2 

3from __future__ import annotations 

4 

5import numpy as np 

6from typing import Union, Any, Literal, List, Dict, overload, Optional 

7 

# Global color definitions.
# Every scheme defines the same set of keys so that the colorizing helpers can
# look up any key regardless of which scheme (latex / terminal / none) is active.
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "torch": r"\textcolor{orange}",
        "dtype_bool": r"\textcolor{gray}",
        "dtype_int": r"\textcolor{blue}",
        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
        "dtype_str": r"\textcolor{red}",
        "device_cuda": r"\textcolor{green}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "torch": "\033[38;5;208m",  # bright orange
        "dtype_bool": "\033[38;5;245m",  # medium grey
        "dtype_int": "\033[38;5;39m",  # bright blue
        "dtype_float": "\033[38;5;167m",  # softer red/coral
        # previously missing here although "latex" and "none" define it
        "dtype_str": "\033[31m",  # red, matching the latex scheme
        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "torch": "",
        "dtype_bool": "",
        "dtype_int": "",
        "dtype_float": "",
        "dtype_str": "",
        "device_cuda": "",
        "reset": "",
    },
}

67 

OutputFormat = Literal["unicode", "latex", "ascii"]

# For each output format, the label used for every summary field; keys are
# looked up by `array_summary` (e.g. "mean", "warning", "true").
SYMBOLS: Dict[OutputFormat, Dict[str, str]] = dict(
    latex=dict(
        range=r"\mathcal{R}",
        mean=r"\mu",
        std=r"\sigma",
        median=r"\tilde{x}",
        distribution=r"\mathbb{P}",
        distribution_log=r"\mathbb{P}_L",
        nan_values=r"\text{NANvals}",
        warning="!!!",
        requires_grad=r"\nabla",
        true=r"\checkmark",
        false=r"\times",
    ),
    unicode=dict(
        range="R",
        mean="μ",
        std="σ",
        median="x̃",
        distribution="ℙ",
        distribution_log="ℙ˪",
        nan_values="NANvals",
        warning="🚨",
        requires_grad="∇",
        true="✓",
        false="✗",
    ),
    ascii=dict(
        range="range",
        mean="mean",
        std="std",
        median="med",
        distribution="dist",
        distribution_log="dist_log",
        nan_values="NANvals",
        warning="!!!",
        requires_grad="requires_grad",
        true="1",
        false="0",
    ),
)
"Symbols for different formats"

# Sparkline glyphs ordered from the lowest level (blank) to the highest.
# latex reuses the unicode block characters; each entry is its own list.
SPARK_CHARS: Dict[OutputFormat, List[str]] = dict(
    unicode=list(" ▁▂▃▄▅▆▇█"),
    ascii=list(" _.-~=#"),
    latex=list(" ▁▂▃▄▅▆▇█"),
)
"characters for sparklines in different formats"

119 

120 

def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
        Number of bins for the histogram stored under `"histogram"` / `"bins"`
        (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        `"status"` is `"ok"`, `"empty array"`, `"all NaN"`, or
        `"error: <message>"`. Fields that could not be computed stay `None`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Detect tensors by class name; this avoids importing torch directly
    A_type: str = type(A).__name__
    result["is_tensor"] = A_type == "Tensor"

    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to a numpy array for the calculations below. Any failure falls
    # back to an empty array, which is then reported as "empty array".
    A_np: np.ndarray
    if result["is_tensor"]:
        is_cuda: bool = False
        try:
            is_cuda = bool(getattr(A, "is_cuda", False))
        except Exception:
            pass
        try:
            # CUDA tensors are first copied to the CPU via `.cpu()`; then the
            # tensor is detached and converted with `.numpy()`
            cpu_tensor = getattr(A, "cpu", lambda: A)() if is_cuda else A
            detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
            A_np = getattr(detached, "numpy", lambda: np.array([]))()
        except Exception:
            A_np = np.array([])
    else:
        # numpy arrays and other array-like objects
        try:
            A_np = np.asarray(A)
        except Exception:
            A_np = np.array([])

    # Basic information; the dtype of a tensor is reported in its own notation
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten multi-dimensional arrays for the statistics
    try:
        A_flat = A_np.flatten() if len(A_np.shape) > 1 else A_np
    except Exception:
        A_flat = A_np

    # NaN detection; fails (and is skipped) for dtypes np.isnan rejects
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except Exception:
        pass

    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            # NaNs cannot be binned, so drop them for the histogram
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            A_hist = A_flat
        result["range"] = (result["min"], result["max"])

        # Histogram data for sparklines (best effort)
        if A_hist.size > 0:
            try:
                # np.histogram converts bool input to uint8 itself but emits a
                # RuntimeWarning ("Converting input from bool to uint8 ...");
                # perform the identical cast explicitly to avoid the warning
                if A_hist.dtype == np.bool_:
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result

275 

276 

def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram counts, rendered as one sparkline character per entry. Any
        array-like (e.g. a plain list) is accepted and converted with
        `np.asarray`.
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format; formats without an entry in `SPARK_CHARS` fall back to
        the ascii characters (defaults to `"unicode"`)
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    # accept any array-like of counts, not only ndarrays (`.max()` below)
    counts: np.ndarray = np.asarray(histogram)

    chars: List[str] = (
        SPARK_CHARS[format] if format in SPARK_CHARS else SPARK_CHARS["ascii"]
    )
    n_levels: int = len(chars)

    # automatic detection of log_y
    if log_y is None:
        # bin the counts into as many levels as there are sparkline characters.
        # If every level except the lowest (first) and highest (last) is empty,
        # use the log scale; otherwise keep the linear scale.
        level_counts = np.histogram(counts, bins=n_levels)[0]
        if level_counts[1:-1].max() > 0:
            log_y = False
        else:
            log_y = True

    # log1p instead of log so that empty bins (count 0) map to 0, not -inf
    hist_data = np.log1p(counts) if log_y else counts

    # scale so the tallest bin maps to the last (highest) character
    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * (n_levels - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # int(): depending on the NumPy version, `round` on a numpy scalar may
    # return a numpy float, which is not a valid list index
    spark: str = "".join(chars[int(round(val))] for val in normalized)
    return spark, log_y

338 

339 

# Fallback values for every `array_summary` option left at `_USE_DEFAULT`.
# `array_summary` reads this dict on each call, so editing it changes the
# defaults for all subsequent calls.
DEFAULT_SETTINGS: Dict[str, Any] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": False,
    "sparkline_bins": 5,
    "sparkline_logy": None,
    "colored": False,
    "as_list": False,
    "eq_char": "=",
}

355 

356 

def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color for `color_key`, if that color is non-empty.

    # Parameters:
     - `text : str`
        text to colorize
     - `color_key : str`
        key into `colors`, e.g. `"range"` or `"dtype_int"`
     - `colors : Dict[str, str]`
        color scheme (one of the values of `COLORS`); must contain
        `color_key`, plus `"reset"` when `using_tex` is `False`
     - `using_tex : bool`
        if `True`, emit a LaTeX `textcolor` command with `text` as its brace
        argument; otherwise wrap `text` in terminal escape codes

    # Returns:
     - `str`
        the colorized text, or `text` unchanged if the color is empty
    """
    color: str = colors[color_key]
    if not color:
        return text
    if using_tex:
        # `{{` / `}}` are literal braces, giving `\textcolor{c}{text}`.
        # The previous `{ {text}}` interpolated a *set* and produced `{'text'}`.
        return f"{color}{{{text}}}"
    return f"{color}{text}{colors['reset']}"

366 

367 

def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize a dtype string, giving a `torch.` prefix and the type name
    separate colors.

    The type name's color is chosen by the first of "bool", "int", "float"
    found in the lowercased string (falling back to the generic `"dtype"`
    color). A `torch.` prefix is colored separately only when the string
    splits on it into exactly two parts.
    """
    lowered: str = dtype_str.lower()
    if "bool" in lowered:
        type_key = "dtype_bool"
    elif "int" in lowered:
        type_key = "dtype_int"
    elif "float" in lowered:
        type_key = "dtype_float"
    else:
        type_key = "dtype"

    pieces = dtype_str.split("torch.") if "torch." in dtype_str else [dtype_str]
    if len(pieces) != 2:
        return apply_color(dtype_str, type_key, colors, using_tex)

    torch_colored: str = apply_color("torch", "torch", colors, using_tex)
    name_colored: str = apply_color(pieces[1], type_key, colors, using_tex)
    return f"{torch_colored}.{name_colored}"

395 

396 

def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    A 1D shape is rendered as its single dimension (e.g. `5`); any other
    shape is rendered as a parenthesized, comma-separated list with each
    dimension colored individually (e.g. `(2,3)`).
    """

    def _colorize(text: str, color_key: str) -> str:
        color: str = colors[color_key]
        if not color:
            return text
        if using_tex:
            # literal braces: `\textcolor{c}{text}` (the old `{ {text}}`
            # interpolated a set and produced `{'text'}`)
            return f"{color}{{{text}}}"
        return f"{color}{text}{colors['reset']}"

    if len(shape_val) == 1:
        # For 1D arrays, still color the dimension value
        return _colorize(str(shape_val[0]), "shape")
    return "(" + ",".join(_colorize(str(dim), "shape") for dim in shape_val) + ")"

416 

417 

def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    Any device string containing "cuda" (case-insensitive) uses the
    `"device_cuda"` color; everything else uses the generic `"device"` color.
    """

    def _colorize(text: str, color_key: str) -> str:
        color: str = colors[color_key]
        if not color:
            return text
        if using_tex:
            # literal braces: `\textcolor{c}{text}` (the old `{ {text}}`
            # interpolated a set and produced `{'text'}`)
            return f"{color}{{{text}}}"
        return f"{color}{text}{colors['reset']}"

    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    return _colorize(device_str, color_key)

437 

438 

439class _UseDefaultType: 

440 pass 

441 

442 

443_USE_DEFAULT = _UseDefaultType() 

444 

445 

@overload
def array_summary(
    as_list: Literal[True],
    **kwargs,
) -> List[str]: ...
@overload
def array_summary(
    as_list: Literal[False],
    **kwargs,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: Optional[bool] = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any argument that is not passed is read from `DEFAULT_SETTINGS` at call
    time; the defaults listed below are that dict's initial values.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
     - `precision : int`
        Decimal places for floating-point statistics (defaults to `2`)
     - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
     - `shape : bool`
        Whether to include shape info (defaults to `True`)
     - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
     - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
        Number of histogram bins, i.e. characters in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline; `None` detects it
        automatically (defaults to `None`)
     - `colored : bool`
        Whether to add color to output (defaults to `False`)
     - `eq_char : str`
        Separator between a label and its value for the NaN count, sparkline,
        shape and device entries (defaults to `"="`)
     - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings.
        If the array is empty or all-NaN, or its statistics could not be
        computed, the statistics are replaced by a single warning entry.
    """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        color: str = colors[color_key]
        if not color:
            return text
        if using_tex:
            # literal braces: `\textcolor{c}{text}` (the old `{ {text}}`
            # interpolated a set and produced `{'text'}`)
            return f"{color}{{{text}}}"
        return f"{color}{text}{colors['reset']}"

    # "dtype" may be present but None if array_info could not determine it
    dtype_str: str = array_data.get("dtype") or ""
    # integer-like dtypes get their range printed without decimals
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ["int", "uint", "bool"]
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array: replace the statistics by a warning
    status = array_data["status"]
    is_error: bool = isinstance(status, str) and status.startswith("error")
    if (
        status in ["empty array", "all NaN", "unknown"]
        or is_error
        or array_data["size"] == 0
    ):
        result_parts.append(colorize(f"{symbols['warning']} {status}", "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # NOTE(review): statistics, range and dtype below always use "=" rather
        # than `eq_char` -- kept as-is; confirm whether that is intentional.

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}={stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = f"{symbols['range']}=[{min_colored},{max_colored}]"
                result_parts.append(range_str)

        # Add sparkline if requested
        if sparkline and array_data["histogram"] is not None:
            # generate_sparkline reports whether log scale was used, which
            # selects the distribution symbol
            spark, used_log = generate_sparkline(
                array_data["histogram"],
                format=fmt,
                log_y=sparkline_logy,
            )
            if spark:
                spark_colored = colorize(spark, "sparkline")
                dist_symbol = (
                    symbols["distribution_log"] if used_log else symbols["distribution"]
                )
                result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype={dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)