Coverage for muutils/tensor_info.py: 88%

199 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 18:27 -0600

1import numpy as np 

2from typing import Union, Any, Literal, List, Dict, overload 

3 

# Global color definitions: per-scheme color prefixes for each annotated
# field. "latex" holds \textcolor commands, "terminal" holds ANSI escape
# codes, and "none" maps every key to "" to disable coloring entirely.
COLORS: Dict[str, Dict[str, str]] = dict(
    latex=dict(
        range=r"\textcolor{purple}",
        mean=r"\textcolor{teal}",
        std=r"\textcolor{orange}",
        median=r"\textcolor{green}",
        warning=r"\textcolor{red}",
        shape=r"\textcolor{magenta}",
        dtype=r"\textcolor{gray}",
        device=r"\textcolor{gray}",
        requires_grad=r"\textcolor{gray}",
        sparkline=r"\textcolor{blue}",
        reset="",
    ),
    terminal=dict(
        range="\033[35m",  # purple
        mean="\033[36m",  # cyan/teal
        std="\033[33m",  # yellow/orange
        median="\033[32m",  # green
        warning="\033[31m",  # red
        shape="\033[95m",  # bright magenta
        dtype="\033[90m",  # gray
        device="\033[90m",  # gray
        requires_grad="\033[90m",  # gray
        sparkline="\033[34m",  # blue
        reset="\033[0m",
    ),
    # same key set as the schemes above, all disabled
    none={
        key: ""
        for key in (
            "range",
            "mean",
            "std",
            "median",
            "warning",
            "shape",
            "dtype",
            "device",
            "requires_grad",
            "sparkline",
            "reset",
        )
    },
)

46 

OutputFormat = Literal["unicode", "latex", "ascii"]

# Per-format display tokens for each statistic / flag shown in a summary.
SYMBOLS: Dict[OutputFormat, Dict[str, str]] = dict(
    latex=dict(
        range=r"\mathcal{R}",
        mean=r"\mu",
        std=r"\sigma",
        median=r"\tilde{x}",
        distribution=r"\mathbb{P}",
        nan_values=r"\text{NANvals}",
        warning="!!!",
        requires_grad=r"\nabla",
        true=r"\checkmark",
        false=r"\times",
    ),
    unicode=dict(
        range="R",
        mean="μ",
        std="σ",
        median="x̃",
        distribution="ℙ",
        nan_values="NANvals",
        warning="🚨",
        requires_grad="∇",
        true="✓",
        false="✗",
    ),
    ascii=dict(
        range="range",
        mean="mean",
        std="std",
        median="med",
        distribution="dist",
        nan_values="NANvals",
        warning="!!!",
        requires_grad="requires_grad",
        true="1",
        false="0",
    ),
)
"Symbols for different formats"

# Ramp of characters (lowest to highest bar) per output format; latex
# reuses the unicode block characters.
SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    key: list(charset)
    for key, charset in (
        ("unicode", " ▁▂▃▄▅▆▇█"),
        ("ascii", " _.-~=#"),
        ("latex", " ▁▂▃▄▅▆▇█"),
    )
}
"characters for sparklines in different formats"

95 

96 

def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Works on numpy arrays, torch tensors (detected by class name so torch is
    never imported), and anything `np.asarray` accepts. Every step is
    best-effort: fields that cannot be computed are left as `None` rather
    than raising.

    # Parameters:
    - `A : array-like`
        Array to analyze (numpy array or torch tensor)
    - `hist_bins : int`
        Number of histogram bins to compute (defaults to `5`)

    # Returns:
    - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values
    """
    # All fields start as None; they are filled in opportunistically below.
    info: Dict[str, Any] = dict.fromkeys(
        (
            "is_tensor",
            "device",
            "requires_grad",
            "shape",
            "dtype",
            "size",
            "has_nans",
            "nan_count",
            "nan_percent",
            "min",
            "max",
            "range",
            "mean",
            "std",
            "median",
            "histogram",
            "bins",
            "status",
        )
    )

    # Detect torch tensors by class name to avoid importing torch directly.
    info["is_tensor"] = type(A).__name__ == "Tensor"

    # Device info is only meaningful for tensors; ignore any failure.
    if info["is_tensor"]:
        try:
            info["device"] = str(getattr(A, "device", None))
        except:  # noqa: E722
            pass

    # Convert to a numpy array for calculations; any failure along the way
    # (cpu transfer, detach, numpy conversion, asarray) yields an empty array.
    try:
        if info["is_tensor"]:
            on_gpu: bool = False
            try:
                on_gpu = bool(getattr(A, "is_cuda", False))
            except:  # noqa: E722
                pass

            host = getattr(A, "cpu", lambda: A)() if on_gpu else A
            detached = getattr(host, "detach", lambda: host)()
            arr = getattr(detached, "numpy", lambda: np.array([]))()
        else:
            arr = np.asarray(A)
    except:  # noqa: E722
        arr = np.array([])

    # Basic metadata; dtype is read from the original tensor when possible.
    try:
        info["shape"] = arr.shape
        info["dtype"] = str(A.dtype if info["is_tensor"] else arr.dtype)
        info["size"] = arr.size
        info["requires_grad"] = getattr(A, "requires_grad", None)
    except:  # noqa: E722
        pass

    # Nothing further to compute for an empty array.
    if info["size"] == 0:
        info["status"] = "empty array"
        return info

    # Statistics operate on a 1-D view.
    try:
        flat = arr.flatten() if len(arr.shape) > 1 else arr
    except:  # noqa: E722
        flat = arr

    # NaN accounting; np.isnan raises for non-float dtypes, in which case
    # the NaN fields simply stay None.
    try:
        nan_mask = np.isnan(flat)
        info["nan_count"] = np.sum(nan_mask)
        info["has_nans"] = info["nan_count"] > 0
        if info["size"] > 0:
            info["nan_percent"] = (info["nan_count"] / info["size"]) * 100
    except:  # noqa: E722
        pass

    # All-NaN arrays have no meaningful statistics.
    if info["has_nans"] and info["nan_count"] == info["size"]:
        info["status"] = "all NaN"
        return info

    try:
        # Pick NaN-aware reducers when NaNs are present, and drop NaNs from
        # the histogram input in that case.
        if info["has_nans"]:
            reducers = (np.nanmin, np.nanmax, np.nanmean, np.nanstd, np.nanmedian)
            hist_source = flat[~nan_mask]
        else:
            reducers = (np.min, np.max, np.mean, np.std, np.median)
            hist_source = flat

        for key, fn in zip(("min", "max", "mean", "std", "median"), reducers):
            info[key] = float(fn(flat))
        info["range"] = (info["min"], info["max"])

        # Histogram data for sparklines; failures here are non-fatal.
        if hist_source.size > 0:
            try:
                counts, edges = np.histogram(hist_source, bins=hist_bins)
                info["histogram"] = counts
                info["bins"] = edges
            except:  # noqa: E722
                pass

        info["status"] = "ok"
    except Exception as e:
        info["status"] = f"error: {str(e)}"

    return info

249 

250 

def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
    - `log_y : bool`
        Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
    - `str`
        Sparkline visualization
    """
    if histogram is None or len(histogram) == 0:
        return ""

    # Unknown formats fall back to the ascii character ramp.
    ramp = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])

    # log1p rather than log so zero-count bins don't blow up.
    values = np.log1p(histogram) if log_y else histogram

    # Scale counts onto the ramp indices; an all-zero histogram maps to
    # the lowest character everywhere.
    peak = values.max()
    if peak > 0:
        scaled = values / peak * (len(ramp) - 1)
    else:
        scaled = np.zeros_like(values)

    return "".join(ramp[int(level)] for level in scaled)

299 

300 

# Module-wide defaults used by `array_summary` for every keyword argument
# that is left as the `_USE_DEFAULT` sentinel.
DEFAULT_SETTINGS: Dict[str, Any] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": False,
    "sparkline_bins": 5,
    "sparkline_logy": False,
    "colored": False,
    "as_list": False,
    "eq_char": "=",
}

316 

317 

318class _UseDefaultType: 

319 pass 

320 

321 

322_USE_DEFAULT = _UseDefaultType() 

323 

324 

@overload
def array_summary(
    as_list: Literal[True],
    **kwargs,
) -> List[str]: ...
@overload
def array_summary(
    as_list: Literal[False],
    **kwargs,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    # Parameters:
    - `array`
        array-like object (numpy array or torch tensor)
    - `fmt : OutputFormat`
        Output format, one of `"unicode"`, `"latex"`, `"ascii"` (defaults to `"unicode"`)
    - `precision : int`
        Decimal places (defaults to `2`)
    - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
    - `shape : bool`
        Whether to include shape info (defaults to `True`)
    - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
    - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
    - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
    - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
    - `sparkline_bins : int`
        Number of histogram bins for the sparkline (defaults to `5`)
    - `sparkline_logy : bool`
        Whether to use logarithmic y-scale for sparkline (defaults to `False`)
    - `colored : bool`
        Whether to add color to output (defaults to `False`)
    - `eq_char : str`
        Separator between a label and its value (defaults to `"="`)
    - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
    - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """
    # Resolve every sentinel-valued argument to its module-level default.
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        """Wrap `text` in the color markup for `color_key`, or pass it through uncolored."""
        if not colors[color_key]:
            return text
        if using_tex:
            # \textcolor{...} takes its text as a braced argument. The braces
            # must be escaped as {{ }} inside the f-string; the previous
            # `{ {text}}` form was a set literal and rendered as e.g.
            # "\textcolor{teal}{'2.00'} " (Python set repr, trailing space).
            return f"{colors[color_key]}{{{text}}}"
        # Terminal: ANSI color on, text, reset.
        return f"{colors[color_key]}{text}{colors['reset']}"

    # Format string for floating-point values
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics (mean, std, median); uses eq_char for consistency with
        # the other fields.
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

        # Range (min, max)
        if array_data["range"] is not None:
            min_val, max_val = array_data["range"]
            min_str: str = f"{min_val:{float_fmt}}"
            max_str: str = f"{max_val:{float_fmt}}"
            min_colored: str = colorize(min_str, "range")
            max_colored: str = colorize(max_str, "range")
            range_str: str = f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
            result_parts.append(range_str)

        # Add sparkline if requested
        if sparkline and array_data["histogram"] is not None:
            spark = generate_sparkline(
                array_data["histogram"], format=fmt, log_y=sparkline_logy
            )
            if spark:
                spark_colored = colorize(spark, "sparkline")
                result_parts.append(f"{symbols['distribution']}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        if len(shape_val) == 1:
            # 1-D arrays show the bare length, matching historical output
            shape_str = str(shape_val[0])
        else:
            shape_str = (
                "(" + ",".join(colorize(str(dim), "shape") for dim in shape_val) + ")"
            )
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype={array_data['dtype']}", "dtype"))

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    # Add gradient info (✓/✗ after the gradient symbol)
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)