Coverage for muutils/tensor_info.py: 88% (199 statements)

import numpy as np
from typing import Union, Any, Literal, List, Dict, overload

# Global color definitions
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "reset": "",
    },
}

OutputFormat = Literal["unicode", "latex", "ascii"]

SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "nan_values": r"\text{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
        "true": r"\checkmark",
        "false": r"\times",
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
        "true": "✓",
        "false": "✗",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
        "true": "1",
        "false": "0",
    },
}
"Symbols for different formats"

SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
}
"characters for sparklines in different formats"


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    # Parameters:
    - `A : array-like`
        Array to analyze (numpy array or torch tensor)
    - `hist_bins : int`
        Number of histogram bins to compute (defaults to `5`)

    # Returns:
    - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    A_type: str = type(A).__name__
    result["is_tensor"] = A_type == "Tensor"

    # Try to get device information if it's a tensor
    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except:  # noqa: E722
            pass

    # Convert to numpy array for calculations
    try:
        # For PyTorch tensors
        if result["is_tensor"]:
            # Check if tensor is on GPU
            is_cuda: bool = False
            try:
                is_cuda = bool(getattr(A, "is_cuda", False))
            except:  # noqa: E722
                pass

            if is_cuda:
                try:
                    # Try to get CPU tensor first
                    cpu_tensor = getattr(A, "cpu", lambda: A)()
                except:  # noqa: E722
                    A_np = np.array([])
            else:
                cpu_tensor = A
            try:
                # For CPU tensor, just detach and convert
                detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
                A_np = getattr(detached, "numpy", lambda: np.array([]))()
            except:  # noqa: E722
                A_np = np.array([])
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except:  # noqa: E722
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except:  # noqa: E722
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    try:
        if len(A_np.shape) > 1:
            A_flat = A_np.flatten()
        else:
            A_flat = A_np
    except:  # noqa: E722
        A_flat = A_np

    # Check for NaN values
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except:  # noqa: E722
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            result["range"] = (result["min"], result["max"])

            # Remove NaNs for histogram
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            result["range"] = (result["min"], result["max"])

            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except:  # noqa: E722
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result
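

# Illustrative usage sketch: `_demo_array_info` is a hypothetical helper (not part of
# the measured module) showing what `array_info` returns for a small numpy array with
# one NaN: plain numeric values, with NaN-aware statistics.
def _demo_array_info() -> None:
    info = array_info(np.array([1.0, 2.0, 3.0, np.nan]))
    assert info["status"] == "ok"
    assert info["shape"] == (4,) and info["size"] == 4
    assert info["nan_count"] == 1 and info["nan_percent"] == 25.0
    assert info["mean"] == 2.0 and info["range"] == (1.0, 3.0)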


def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
    - `log_y : bool`
        Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
    - `str`
        Sparkline visualization
    """
    if histogram is None or len(histogram) == 0:
        return ""

    # Get the appropriate character set
    if format in SPARK_CHARS:
        chars = SPARK_CHARS[format]
    else:
        chars = SPARK_CHARS["ascii"]

    # Handle log scale
    if log_y:
        # Add small value to avoid log(0)
        hist_data = np.log1p(histogram)
    else:
        hist_data = histogram

    # Normalize to character set range
    if hist_data.max() > 0:
        normalized = hist_data / hist_data.max() * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # Convert to characters
    spark = ""
    for val in normalized:
        idx = int(val)
        spark += chars[idx]

    return spark
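

# Illustrative usage sketch: `_demo_generate_sparkline` is a hypothetical helper
# showing how histogram counts map to characters; the tallest bin gets the last
# character of the chosen SPARK_CHARS set.
def _demo_generate_sparkline() -> None:
    hist = np.array([1, 5, 2, 8, 0])
    assert len(generate_sparkline(hist)) == 5  # one character per bin
    assert generate_sparkline(hist, format="ascii")[3] == "#"  # tallest bin -> "#"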


DEFAULT_SETTINGS: Dict[str, Any] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=False,
    sparkline_bins=5,
    sparkline_logy=False,
    colored=False,
    as_list=False,
    eq_char="=",
)
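

# Sentinel type: lets `array_summary` tell "argument not passed" apart from explicit
# falsy values (`False`, `0`, `""`), falling back to `DEFAULT_SETTINGS` only in the
# former case.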
class _UseDefaultType:
    pass


_USE_DEFAULT = _UseDefaultType()


@overload
def array_summary(
    as_list: Literal[True],
    **kwargs,
) -> List[str]: ...
@overload
def array_summary(
    as_list: Literal[False],
    **kwargs,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
351 """Format array information into a readable summary.
353 # Parameters:
354 - `array`
355 array-like object (numpy array or torch tensor)
356 - `precision : int`
357 Decimal places (defaults to `2`)
358 - `format : Literal["unicode", "latex", "ascii"]`
359 Output format (defaults to `{default_fmt}`)
360 - `stats : bool`
361 Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
362 - `shape : bool`
363 Whether to include shape info (defaults to `True`)
364 - `dtype : bool`
365 Whether to include dtype info (defaults to `True`)
366 - `device : bool`
367 Whether to include device info for torch tensors (defaults to `True`)
368 - `requires_grad : bool`
369 Whether to include requires_grad info for torch tensors (defaults to `True`)
370 - `sparkline : bool`
371 Whether to include a sparkline visualization (defaults to `False`)
372 - `sparkline_width : int`
373 Width of the sparkline (defaults to `20`)
374 - `sparkline_logy : bool`
375 Whether to use logarithmic y-scale for sparkline (defaults to `False`)
376 - `colored : bool`
377 Whether to add color to output (defaults to `False`)
378 - `as_list : bool`
379 Whether to return as list of strings instead of joined string (defaults to `False`)
381 # Returns:
382 - `Union[str, List[str]]`
383 Formatted statistical summary, either as string or list of strings
384 """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]
    # Helper function to colorize text
    def colorize(text: str, color_key: str) -> str:
        if using_tex:
            return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
        else:
            return (
                f"{colors[color_key]}{text}{colors['reset']}"
                if colors[color_key]
                else text
            )
    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}={stat_colored}")

        # Range (min, max)
        if array_data["range"] is not None:
            min_val, max_val = array_data["range"]
            min_str: str = f"{min_val:{float_fmt}}"
            max_str: str = f"{max_val:{float_fmt}}"
            min_colored: str = colorize(min_str, "range")
            max_colored: str = colorize(max_str, "range")
            range_str: str = f"{symbols['range']}=[{min_colored},{max_colored}]"
            result_parts.append(range_str)

        # Add sparkline if requested
        if sparkline and array_data["histogram"] is not None:
            spark = generate_sparkline(
                array_data["histogram"], format=fmt, log_y=sparkline_logy
            )
            if spark:
                spark_colored = colorize(spark, "sparkline")
                result_parts.append(f"{symbols['distribution']}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        if len(shape_val) == 1:
            shape_str = str(shape_val[0])
        else:
            shape_str = (
                "(" + ",".join(colorize(str(dim), "shape") for dim in shape_val) + ")"
            )
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype={array_data['dtype']}", "dtype"))

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)
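

# Illustrative usage sketch: `_demo_array_summary` is a hypothetical helper showing
# the one-line summary produced with the defaults above, and the list form selected
# via `as_list=True`.
def _demo_array_summary() -> None:
    x = np.array([[1.0, 2.0], [3.0, 4.0]])
    line = array_summary(x, fmt="ascii")
    assert "mean=2.50" in line and "shape=(2,2)" in line
    parts = array_summary(x, fmt="ascii", as_list=True)
    assert isinstance(parts, list) and "dtype=float64" in parts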