Coverage for muutils/tensor_info.py: 89%

246 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-28 17:24 +0000

1"get metadata about a tensor, mostly for `muutils.dbg`" 

2 

3from __future__ import annotations 

4 

5import numpy as np 

6from typing import Union, Any, Literal, List, Dict, overload, Optional 

7 

# Global color definitions
# Each scheme maps a semantic key to the markup used to color that kind of
# output. Every scheme defines the same set of keys, so any key can be looked
# up regardless of which scheme is active.
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "torch": r"\textcolor{orange}",
        "dtype_bool": r"\textcolor{gray}",
        "dtype_int": r"\textcolor{blue}",
        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
        "dtype_str": r"\textcolor{red}",
        "device_cuda": r"\textcolor{green}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "torch": "\033[38;5;208m",  # bright orange
        "dtype_bool": "\033[38;5;245m",  # medium grey
        "dtype_int": "\033[38;5;39m",  # bright blue
        "dtype_float": "\033[38;5;167m",  # softer red/coral
        "dtype_str": "\033[31m",  # red, matching the latex scheme
        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "torch": "",
        "dtype_bool": "",
        "dtype_int": "",
        "dtype_float": "",
        "dtype_str": "",
        "device_cuda": "",
        "reset": "",
    },
}

67 

OutputFormat = Literal["unicode", "latex", "ascii"]

# Labels for each piece of a summary, one table per output format.
# Every table carries the same keys, in the same order.
SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": dict(
        range=r"\mathcal{R}",
        mean=r"\mu",
        std=r"\sigma",
        median=r"\tilde{x}",
        distribution=r"\mathbb{P}",
        distribution_log=r"\mathbb{P}_L",
        nan_values=r"\text{NANvals}",
        warning="!!!",
        requires_grad=r"\nabla",
        true=r"\checkmark",
        false=r"\times",
    ),
    "unicode": dict(
        range="R",
        mean="μ",
        std="σ",
        median="x̃",
        distribution="ℙ",
        distribution_log="ℙ˪",
        nan_values="NANvals",
        warning="🚨",
        requires_grad="∇",
        true="✓",
        false="✗",
    ),
    "ascii": dict(
        range="range",
        mean="mean",
        std="std",
        median="med",
        distribution="dist",
        distribution_log="dist_log",
        nan_values="NANvals",
        warning="!!!",
        requires_grad="requires_grad",
        true="1",
        false="0",
    ),
}
"Symbols for different formats"

# Sparkline "levels", ordered from lowest (blank) to highest.
# LaTeX output reuses the unicode block characters, as its own list.
SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": [*" ▁▂▃▄▅▆▇█"],
    "ascii": [*" _.-~=#"],
    "latex": [*" ▁▂▃▄▅▆▇█"],
}
"characters for sparklines in different formats"

119 

120 

def _to_numpy(A: Any, is_tensor: bool) -> np.ndarray:
    """Best-effort conversion of `A` to a numpy array; returns an empty array on failure.

    Tensor-like objects are handled purely through duck-typed attributes
    (`is_cuda`, `cpu`, `detach`, `numpy`, `float`), so torch never has to be
    imported here.
    """
    if not is_tensor:
        try:
            return np.asarray(A)
        except Exception:
            return np.array([])

    # tensors flagged as CUDA are moved to the CPU before conversion
    try:
        is_cuda: bool = bool(getattr(A, "is_cuda", False))
    except Exception:
        is_cuda = False

    tensor: Any = A
    if is_cuda:
        try:
            tensor = getattr(A, "cpu", lambda: A)()
        except Exception:
            return np.array([])

    # detach, then convert
    try:
        detached: Any = getattr(tensor, "detach", lambda: tensor)()
    except Exception:
        return np.array([])
    try:
        return np.asarray(getattr(detached, "numpy", lambda: np.array([]))())
    except Exception:
        pass

    # numpy has no counterpart for some reduced-precision float dtypes
    # (e.g. bfloat16), so `.numpy()` raises for them; upcast via `.float()`
    # and retry instead of reporting a non-empty tensor as empty
    try:
        if "float" in str(getattr(detached, "dtype", "")):
            return np.asarray(detached.float().numpy())
    except Exception:
        pass
    return np.array([])


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Every step is best-effort: a step that fails leaves its entries as `None`
    instead of raising, and the overall outcome is recorded in `"status"`.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array, torch tensor, or anything `np.asarray` accepts)
     - `hist_bins : int`
        Number of histogram bins computed for sparklines (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        `"status"` is `"ok"`, `"empty array"` (also used when conversion fails),
        `"all NaN"`, or `"error: <message>"`. Statistics ignore NaN entries;
        `"histogram"`/`"bins"` come from `np.histogram` over the non-NaN values,
        or stay `None` if that fails.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # detect tensors by class name, which avoids importing torch directly
    result["is_tensor"] = type(A).__name__ == "Tensor"

    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    A_np: np.ndarray = _to_numpy(A, result["is_tensor"])

    # basic metadata; for tensors the dtype is read from the original object
    # so it reflects the tensor's own dtype rather than the converted array's
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # statistics work on a 1-D view regardless of the input rank (incl. 0-d)
    A_flat: np.ndarray = A_np.ravel()

    # NaN detection; dtypes np.isnan cannot handle leave these entries as None
    nan_mask: np.ndarray = np.zeros(A_flat.shape, dtype=bool)
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"]:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except Exception:
        pass

    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    try:
        has_nans: bool = bool(result["has_nans"])
        if has_nans:
            # nan-aware reductions skip NaN entries
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            A_hist: np.ndarray = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            A_hist = A_flat
        result["range"] = (result["min"], result["max"])

        # histogram data for sparklines
        if A_hist.size > 0:
            # np.histogram converts bool input to uint8 itself but emits a
            # RuntimeWarning doing so; convert explicitly to keep it quiet
            if A_hist.dtype == np.bool_:
                A_hist = A_hist.astype(np.uint8)
            try:
                result["histogram"], result["bins"] = np.histogram(
                    A_hist, bins=hist_bins
                )
            except Exception:
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {e}"

    return result

276 

277 

def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    Each histogram entry becomes one character from `SPARK_CHARS[format]`.
    Entries are scaled so that the largest entry maps to the tallest character.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram counts (any 1-D numeric sequence is accepted)
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format; unrecognized formats fall back to the ascii characters
        (defaults to `"unicode"`)
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used.
        A `None` or empty histogram yields `("", False)`.
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    counts: np.ndarray = np.asarray(histogram)

    # pick the character set for this format
    chars: List[str] = SPARK_CHARS[format] if format in SPARK_CHARS else SPARK_CHARS["ascii"]
    n_levels: int = len(chars)

    # automatic detection of log_y
    if log_y is None:
        # bin the histogram values into one bucket per sparkline level; if every
        # bucket except the smallest (first) and largest (last) is empty, use
        # the log scale, otherwise keep the linear scale
        level_counts = np.histogram(counts, bins=n_levels)[0]
        interior = level_counts[1:-1]
        log_y = not (interior.size > 0 and interior.max() > 0)

    # log1p maps zero counts to zero, so empty bins stay at the lowest level
    data = np.log1p(counts) if log_y else counts

    # normalize to the range of character indices
    peak = data.max()
    if peak > 0:
        normalized = data / peak * (n_levels - 1)
    else:
        normalized = np.zeros_like(data)

    # int() keeps the index a plain Python int whatever scalar type round() yields
    spark: str = "".join(chars[int(round(level))] for level in normalized)

    return spark, log_y

339 

340 

# Fallback values for every `array_summary` option that the caller leaves
# unspecified; `array_summary` reads this dict at call time.
DEFAULT_SETTINGS: Dict[str, Any] = {
    "fmt": "unicode",  # output format: "unicode", "latex" or "ascii"
    "precision": 2,  # decimal places for float statistics
    "stats": True,  # include statistics
    "shape": True,  # include the shape
    "dtype": True,  # include the dtype
    "device": True,  # include the device (tensors only)
    "requires_grad": True,  # include requires_grad (tensors only)
    "sparkline": False,  # include a histogram sparkline
    "sparkline_bins": 5,  # number of histogram bins for the sparkline
    "sparkline_logy": None,  # sparkline y-scale: None selects automatically
    "colored": False,  # add color markup
    "as_list": False,  # return a list of parts instead of a joined string
    "eq_char": "=",  # separator between labels and values
}

356 

357 

def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color markup registered under `color_key`.

    - If `colors[color_key]` is empty, `text` is returned unchanged.
    - In LaTeX mode the color command is followed by `{text}`.
    - Otherwise the color code is prefixed and `colors['reset']` is appended.
    """
    code: str = colors[color_key]
    if not code:
        return text
    if using_tex:
        return f"{code}{{{text}}}"
    return f"{code}{text}{colors['reset']}"

367 

368 

def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize a dtype string, coloring a `torch.` prefix separately from the type name.

    - The type name gets a color by family: "dtype_bool", "dtype_int" or
      "dtype_float" when the lower-cased string contains "bool", "int" or
      "float" (checked in that order), otherwise "dtype".
    - A string that splits into exactly two pieces around "torch." is rendered
      as a "torch"-colored label, a dot, and the colored remainder.
    """
    pieces: List[str] = dtype_str.split("torch.")
    # exactly one "torch." occurrence -> render the prefix on its own
    torch_label: Optional[str] = (
        apply_color("torch", "torch", colors, using_tex) if len(pieces) == 2 else None
    )
    bare_type: str = pieces[1] if torch_label is not None else dtype_str

    lowered: str = dtype_str.lower()
    family_key: str = next(
        (
            key
            for needle, key in (
                ("bool", "dtype_bool"),
                ("int", "dtype_int"),
                ("float", "dtype_float"),
            )
            if needle in lowered
        ),
        "dtype",
    )

    colored_type: str = apply_color(bare_type, family_key, colors, using_tex)
    if torch_label is None:
        return colored_type
    return f"{torch_label}.{colored_type}"

396 

397 

def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
    """Format a shape, coloring every dimension with the "shape" color.

    - A 1-D shape is rendered as its single dimension without parentheses
      (e.g. `5`).
    - Any other shape becomes a parenthesized, comma-separated list
      (e.g. `(2,3)`, or `()` for an empty shape).
    """
    if len(shape_val) == 1:
        return apply_color(str(shape_val[0]), "shape", colors, using_tex)
    dims: str = ",".join(
        apply_color(str(dim), "shape", colors, using_tex) for dim in shape_val
    )
    return f"({dims})"

417 

418 

def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format a device string, highlighting CUDA devices.

    Device names containing "cuda" (case-insensitive) use the "device_cuda"
    color; every other device uses the plain "device" color.
    """
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    return apply_color(device_str, color_key, colors, using_tex)

438 

439 

class _UseDefaultType:
    """Sentinel type marking an `array_summary` argument as "not given".

    A dedicated sentinel is needed because `None` is itself a meaningful
    setting value (e.g. `sparkline_logy=None` selects the scale automatically).
    """

    def __repr__(self) -> str:
        # shows up in help()/inspect.signature instead of "<... object at 0x...>"
        return "_USE_DEFAULT"


_USE_DEFAULT = _UseDefaultType()

445 

446 

@overload
def array_summary(
    array: Any,
    as_list: Literal[True],
    **kwargs,
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    as_list: Literal[False],
    **kwargs,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: Optional[bool] = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any argument that is not passed is taken from `DEFAULT_SETTINGS` at call
    time; the defaults listed below are that dict's initial values.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
     - `precision : int`
        Decimal places for floating-point values (defaults to `2`)
     - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) and the value range (defaults to `True`)
     - `shape : bool`
        Whether to include shape info (defaults to `True`)
     - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
     - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
        Number of histogram bins, i.e. characters in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline; `None` selects it
        automatically (defaults to `None`)
     - `colored : bool`
        Whether to add color to output (defaults to `False`)
     - `eq_char : str`
        Separator placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """

    def _resolve(value: Any, key: str) -> Any:
        # `_USE_DEFAULT` rather than `None` marks "not given", because `None`
        # is a meaningful value for `sparkline_logy`
        return DEFAULT_SETTINGS[key] if value is _USE_DEFAULT else value

    fmt = _resolve(fmt, "fmt")
    precision = _resolve(precision, "precision")
    stats = _resolve(stats, "stats")
    shape = _resolve(shape, "shape")
    dtype = _resolve(dtype, "dtype")
    device = _resolve(device, "device")
    requires_grad = _resolve(requires_grad, "requires_grad")
    sparkline = _resolve(sparkline, "sparkline")
    sparkline_bins = _resolve(sparkline_bins, "sparkline_bins")
    sparkline_logy = _resolve(sparkline_logy, "sparkline_logy")
    colored = _resolve(colored, "colored")
    as_list = _resolve(as_list, "as_list")
    eq_char = _resolve(eq_char, "eq_char")

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # color scheme depends on the output format and the colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        return apply_color(text, color_key, colors, using_tex)

    # "dtype" is present but None when it could not be determined
    dtype_str: str = array_data.get("dtype") or ""
    # integer-like dtypes ("int" also covers "uint", plus "bool") print their range without decimals
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ("int", "bool")
    )

    float_fmt: str = f".{precision}f"

    if (
        array_data["status"] in ("empty array", "all NaN", "unknown")
        or array_data["size"] == 0
    ):
        status: str = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # NaN warning comes first if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = (
                f"{symbols['warning']} {symbols['nan_values']}{eq_char}"
                f"{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            )
            result_parts.append(colorize(nan_str, "warning"))

        if stats:
            for stat_key in ("mean", "std", "median"):
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                result_parts.append(
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )

        if sparkline and array_data["histogram"] is not None:
            # generate_sparkline reports whether log scale was used, which
            # selects the distribution symbol
            spark, used_log = generate_sparkline(
                array_data["histogram"],
                format=fmt,
                log_y=sparkline_logy,
            )
            if spark:
                spark_colored: str = colorize(spark, "sparkline")
                dist_symbol: str = (
                    symbols["distribution_log"] if used_log else symbols["distribution"]
                )
                result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    if shape and array_data["shape"]:
        shape_str: str = format_shape_colored(array_data["shape"], colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    if dtype and array_data["dtype"]:
        dtype_colored: str = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # device info only for tensors that reported one
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored: str = format_device_colored(
            array_data["device"], colors, using_tex
        )
        result_parts.append(f"device{eq_char}{device_colored}")

    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)