Coverage for muutils / tensor_info.py: 89%

266 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1"get metadata about a tensor, mostly for `muutils.dbg`" 

2 

3from __future__ import annotations 

4 

5import numpy as np 

6from typing import Union, Any, Literal, List, Dict, overload, Optional, TYPE_CHECKING 

7 

8if TYPE_CHECKING: 

9 from typing import TypedDict 

10else: 

11 try: 

12 from typing import TypedDict 

13 except ImportError: 

14 from typing_extensions import TypedDict 

15 

# Global color definitions.
# Keyed first by color scheme ("latex", "terminal", "none") and then by the
# semantic role of the colored text. Every scheme defines the same set of keys,
# so code can index any role without knowing which scheme was selected.
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "torch": r"\textcolor{orange}",
        "dtype_bool": r"\textcolor{gray}",
        "dtype_int": r"\textcolor{blue}",
        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
        "dtype_str": r"\textcolor{red}",
        "device_cuda": r"\textcolor{green}",
        "reset": "",  # LaTeX colors are scoped by braces, no reset needed
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "torch": "\033[38;5;208m",  # bright orange
        "dtype_bool": "\033[38;5;245m",  # medium grey
        "dtype_int": "\033[38;5;39m",  # bright blue
        "dtype_float": "\033[38;5;167m",  # softer red/coral
        # was missing here while present in "latex" and "none"; red to match latex
        "dtype_str": "\033[31m",  # red
        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "torch": "",
        "dtype_bool": "",
        "dtype_int": "",
        "dtype_float": "",
        "dtype_str": "",
        "device_cuda": "",
        "reset": "",
    },
}

75 

# Output formats accepted by `array_summary` (its `fmt` argument); also the
# keys of the `SYMBOLS` and `SPARK_CHARS` tables.
OutputFormat = Literal["unicode", "latex", "ascii"]


class ArraySummarySettings(TypedDict):
    """Type definition for array_summary default settings.

    One key per keyword argument of `array_summary`. `DEFAULT_SETTINGS` is an
    instance of this, and `array_summary` reads from it for every argument the
    caller leaves unset.
    """

    fmt: OutputFormat  # output format ("unicode", "latex" or "ascii")
    precision: int  # decimal places used for float statistics
    stats: bool  # include mean / std / median
    shape: bool  # include the array shape
    dtype: bool  # include the dtype
    device: bool  # include the device (tensors only)
    requires_grad: bool  # include requires_grad (tensors only)
    sparkline: bool  # include a histogram sparkline
    sparkline_bins: int  # number of histogram bins behind the sparkline
    sparkline_logy: Optional[bool]  # log y-scale for the sparkline; None = auto-detect
    colored: bool  # wrap parts in terminal / LaTeX color codes
    as_list: bool  # return the list of parts instead of one joined string
    eq_char: str  # text placed between a label and its value (e.g. `shape=`)

95 

96 

# Per-format labels and glyphs used by `array_summary` to mark each part of a
# summary line. Every format defines the same keys, because `array_summary`
# indexes the selected table directly.
SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",  # sparkline label, linear y-scale
        "distribution_log": r"\mathbb{P}_L",  # sparkline label, log y-scale
        "nan_values": r"\text{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
        "true": r"\checkmark",  # shown after requires_grad when it is set
        "false": r"\times",  # shown after requires_grad when it is not set
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",  # sparkline label, linear y-scale
        "distribution_log": "ℙ˪",  # sparkline label, log y-scale
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
        "true": "✓",
        "false": "✗",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",  # sparkline label, linear y-scale
        "distribution_log": "dist_log",  # sparkline label, log y-scale
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
        "true": "1",
        "false": "0",
    },
}
"Symbols for different formats"

139 

# Sparkline "levels" per output format, ordered from lowest to highest: index 0
# is a blank, the last entry is the fullest glyph. LaTeX reuses the unicode
# block elements.
SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"],
    "ascii": [" ", "_", ".", "-", "~", "=", "#"],
    "latex": [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"],
}
"characters for sparklines in different formats"

146 

147 

def _tensor_to_numpy(A: Any) -> np.ndarray:
    """Best-effort conversion of a torch-like tensor to numpy; empty array on failure."""
    # detach from autograd, then move to host memory. `.cpu()` is called for every
    # device (not only CUDA) so tensors on other accelerators convert too.
    try:
        detached = getattr(A, "detach", lambda: A)()
        cpu_tensor = getattr(detached, "cpu", lambda: detached)()
    except Exception:
        return np.array([])

    try:
        return getattr(cpu_tensor, "numpy", lambda: np.array([]))()
    except Exception:
        pass

    # dtypes without a numpy equivalent (e.g. bfloat16) make `.numpy()` raise;
    # upcast so statistics can still be computed. The reported dtype is taken
    # from the original tensor, not from this converted copy.
    try:
        return cpu_tensor.float().numpy()
    except Exception:
        return np.array([])


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Ordinary failures never propagate: anything that cannot be computed is left
    as `None`, and `result["status"]` describes the outcome.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array, torch tensor, or anything `np.asarray` accepts)
     - `hist_bins : int`
        Number of histogram bins to compute for sparklines (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        `status` is one of `"ok"`, `"empty array"`, `"all NaN"` or `"error: <message>"`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name.
    # This avoids importing torch directly.
    result["is_tensor"] = type(A).__name__ == "Tensor"

    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to numpy for all calculations
    A_np: np.ndarray
    try:
        A_np = _tensor_to_numpy(A) if result["is_tensor"] else np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Basic information
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # 1-D view for statistics (also lifts 0-d arrays to 1-D)
    A_flat: np.ndarray
    try:
        A_flat = np.ravel(A_np)
    except Exception:
        A_flat = A_np

    # NaN detection; fails (and is skipped) for dtypes np.isnan does not support
    nan_mask: Optional[np.ndarray] = None
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        result_size: int = result["size"]  # ty: ignore[invalid-assignment]
        if result_size > 0:
            result["nan_percent"] = (result["nan_count"] / result_size) * 100
    except Exception:
        pass

    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    try:
        A_hist: np.ndarray
        if result["has_nans"] and nan_mask is not None:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            # histogram only over the non-NaN values
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            A_hist = A_flat
        result["range"] = (result["min"], result["max"])

        # histogram data for sparklines; a failure here only drops the histogram
        if A_hist.size > 0:
            try:
                # np.histogram casts bool input to uint8 itself but emits a
                # RuntimeWarning; casting up front gives the same bins silently
                if A_hist.dtype == np.bool_:
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result

305 

306 

# Formats for which `generate_sparkline` has a character set in `SPARK_CHARS`.
SparklineFormat = Literal["unicode", "latex", "ascii"]


def generate_sparkline(
    histogram: np.ndarray,
    format: SparklineFormat = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram counts, one per bin (any 1-D sequence of numbers is accepted)
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`); unknown formats fall back to
        the ascii character set
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization (one character per bin) and whether log scale was used
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    # accept plain lists/tuples as well as numpy arrays
    counts: np.ndarray = np.asarray(histogram)

    chars: List[str] = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])

    # automatic detection of log_y
    if log_y is None:
        # Bin the histogram values into as many levels as there are sparkline
        # characters. If every level except the lowest and highest is empty,
        # use the log scale; otherwise keep the linear scale.
        # (Needs at least 3 characters; every SPARK_CHARS set has 7 or 9.)
        level_counts: np.ndarray = np.histogram(counts, bins=len(chars))[0]
        if level_counts[1:-1].max() > 0:
            log_y = False
        else:
            log_y = True

    # log1p keeps zero counts at zero while compressing large ones
    hist_data: np.ndarray = np.log1p(counts) if log_y else counts

    # Normalize to the character index range [0, len(chars) - 1]
    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # int(round(float(...))) keeps Python's round-half-to-even while guaranteeing
    # a plain int index whatever numpy scalar type `normalized` yields
    spark: str = "".join(chars[int(round(float(level)))] for level in normalized)

    return spark, log_y

371 

372 

# Fallback values for every `array_summary` argument a caller leaves unset.
# Built through the TypedDict constructor, which returns a plain dict.
DEFAULT_SETTINGS: ArraySummarySettings = ArraySummarySettings(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=False,
    sparkline_bins=5,
    sparkline_logy=None,
    colored=False,
    as_list=False,
    eq_char="=",
)

388 

389 

def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color stored under `color_key` in `colors`.

    # Parameters:
     - `text : str`
        text to color
     - `color_key : str`
        key looked up in `colors` (a missing key raises `KeyError`)
     - `colors : Dict[str, str]`
        color table; for non-TeX output it must also hold `"reset"` whenever
        the selected color is non-empty
     - `using_tex : bool`
        `True` renders `<color>{text}`, `False` renders `<color>text<reset>`

    # Returns:
     - `str`
        the colored text, or `text` itself when the selected color is empty
    """
    color: str = colors[color_key]
    if not color:
        return text
    if using_tex:
        return f"{color}{{{text}}}"
    return f"{color}{text}{colors['reset']}"

399 

400 

def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize dtype string with specific colors for torch and type names.

    When `dtype_str` contains exactly one `"torch."`, the output is a colored
    `torch` followed by `.` and the colored text after that prefix (anything
    before the prefix is dropped). The type name is colored by family, found by
    a case-insensitive substring search over the whole input in this order:
    `bool` -> `"dtype_bool"`, `int` -> `"dtype_int"`, `float` -> `"dtype_float"`,
    otherwise `"dtype"`.
    """
    lowered: str = dtype_str.lower()
    family_rules = (
        ("bool", "dtype_bool"),
        ("int", "dtype_int"),
        ("float", "dtype_float"),
    )
    family_key: str = next(
        (key for needle, key in family_rules if needle in lowered),
        "dtype",
    )

    if dtype_str.count("torch.") == 1:
        _, _, type_name = dtype_str.partition("torch.")
        torch_colored: str = apply_color("torch", "torch", colors, using_tex)
        name_colored: str = apply_color(type_name, family_key, colors, using_tex)
        return f"{torch_colored}.{name_colored}"

    return apply_color(dtype_str, family_key, colors, using_tex)

428 

429 

def format_shape_colored(
    shape_val: Any, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    A 1-D shape is rendered as its single colored dimension without
    parentheses (e.g. `5`). Any other shape becomes a parenthesized,
    comma-separated tuple with each dimension colored (e.g. `(2,3)`, or `()`
    for a 0-d shape).

    Coloring goes through the module-level `apply_color`, so the wrapping rules
    match the rest of the summary instead of a private copy.
    """
    if len(shape_val) == 1:
        return apply_color(str(shape_val[0]), "shape", colors, using_tex)

    dims: str = ",".join(
        apply_color(str(dim), "shape", colors, using_tex) for dim in shape_val
    )
    return f"({dims})"

451 

452 

def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    Device strings containing `"cuda"` (case-insensitive) get the
    `"device_cuda"` color; all others get the plain `"device"` color. Coloring
    goes through the module-level `apply_color` rather than a private copy.
    """
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    return apply_color(device_str, color_key, colors, using_tex)

472 

473 

474class _UseDefaultType: 

475 pass 

476 

477 

478_USE_DEFAULT = _UseDefaultType() 

479 

480 

# Typing-only overloads of `array_summary`: they let type checkers infer the
# return type from `as_list`. The real implementation follows these stubs.
# `as_list=True` (keyword-only here) -> List[str]
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    *,
    as_list: Literal[True],
) -> List[str]: ...
# `as_list=False` or omitted -> str
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    as_list: Literal[False] = ...,
) -> str: ...
# `as_list` only known to be a bool -> either form
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    as_list: bool = ...,
) -> Union[str, List[str]]: ...

def array_summary(
    array: Any,
    fmt: Union[OutputFormat, _UseDefaultType] = _USE_DEFAULT,
    precision: Union[int, _UseDefaultType] = _USE_DEFAULT,
    stats: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    shape: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    dtype: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    device: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    requires_grad: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline_bins: Union[int, _UseDefaultType] = _USE_DEFAULT,
    sparkline_logy: Union[Optional[bool], _UseDefaultType] = _USE_DEFAULT,
    colored: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    eq_char: Union[str, _UseDefaultType] = _USE_DEFAULT,
    as_list: Union[bool, _UseDefaultType] = _USE_DEFAULT,
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Every argument left unset is taken from `DEFAULT_SETTINGS` at call time;
    the defaults listed below are the values shipped in that dict.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
     - `precision : int`
        Decimal places (defaults to `2`)
     - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
     - `shape : bool`
        Whether to include shape info (defaults to `True`)
     - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
     - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
        Number of histogram bins, i.e. characters, in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline; `None` detects it
        automatically (defaults to `None`)
     - `colored : bool`
        Whether to add color to output (defaults to `False`)
     - `eq_char : str`
        Text placed between each label and its value, e.g. `μ=0.50`
        (defaults to `"="`)
     - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """
    # Resolve unset arguments from DEFAULT_SETTINGS. The lookup happens here, on
    # every call, so later changes to that dict take effect. Explicit isinstance
    # checks let type checkers narrow each parameter.
    if isinstance(fmt, _UseDefaultType):
        fmt = DEFAULT_SETTINGS["fmt"]
    if isinstance(precision, _UseDefaultType):
        precision = DEFAULT_SETTINGS["precision"]
    if isinstance(stats, _UseDefaultType):
        stats = DEFAULT_SETTINGS["stats"]
    if isinstance(shape, _UseDefaultType):
        shape = DEFAULT_SETTINGS["shape"]
    if isinstance(dtype, _UseDefaultType):
        dtype = DEFAULT_SETTINGS["dtype"]
    if isinstance(device, _UseDefaultType):
        device = DEFAULT_SETTINGS["device"]
    if isinstance(requires_grad, _UseDefaultType):
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if isinstance(sparkline, _UseDefaultType):
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if isinstance(sparkline_bins, _UseDefaultType):
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if isinstance(sparkline_logy, _UseDefaultType):
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if isinstance(colored, _UseDefaultType):
        colored = DEFAULT_SETTINGS["colored"]
    if isinstance(as_list, _UseDefaultType):
        as_list = DEFAULT_SETTINGS["as_list"]
    if isinstance(eq_char, _UseDefaultType):
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        # shares the module-wide coloring rules instead of re-implementing them
        return apply_color(text, color_key, colors, using_tex)

    # `array_info` stores None for "dtype" when it could not determine it, so
    # fall back to "" explicitly before calling `.lower()`
    dtype_str: str = array_data.get("dtype") or ""
    # "int" also matches "uint"; bool ranges print as 0/1 rather than floats
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ("int", "bool")
    )

    float_fmt: str = f".{precision}f"

    # Handle empty / all-NaN arrays with a single warning part
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # NaN warning goes first so it stands out
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

        # Range (min, max)
        if array_data["range"] is not None:
            min_val, max_val = array_data["range"]
            if is_int_dtype:
                min_str: str = f"{int(min_val):d}"
                max_str: str = f"{int(max_val):d}"
            else:
                min_str = f"{min_val:{float_fmt}}"
                max_str = f"{max_val:{float_fmt}}"
            min_colored: str = colorize(min_str, "range")
            max_colored: str = colorize(max_str, "range")
            range_str: str = f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
            result_parts.append(range_str)

        if sparkline and array_data["histogram"] is not None:
            # generate_sparkline reports whether log scale was used, which
            # selects the distribution symbol
            spark, used_log = generate_sparkline(
                array_data["histogram"],
                format=fmt,
                log_y=sparkline_logy,
            )
            if spark:
                spark_colored = colorize(spark, "sparkline")
                dist_symbol = (
                    symbols["distribution_log"] if used_log else symbols["distribution"]
                )
                result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # device only for tensors that reported one
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)