Coverage for muutils/tensor_info.py: 89%
246 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-28 17:24 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-28 17:24 +0000
1"get metadata about a tensor, mostly for `muutils.dbg`"
3from __future__ import annotations
5import numpy as np
6from typing import Union, Any, Literal, List, Dict, overload, Optional
# Global color definitions
# Each scheme ("latex", "terminal", "none") maps the same set of color keys to
# the code placed in front of the colored text: a LaTeX `\textcolor{...}`
# command, an ANSI escape sequence, or an empty string for "no color".
# "reset" is the code appended after the text for non-LaTeX output.
# All schemes must expose identical keys so they can be swapped freely.
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "torch": r"\textcolor{orange}",
        "dtype_bool": r"\textcolor{gray}",
        "dtype_int": r"\textcolor{blue}",
        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
        "dtype_str": r"\textcolor{red}",
        "device_cuda": r"\textcolor{green}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "torch": "\033[38;5;208m",  # bright orange
        "dtype_bool": "\033[38;5;245m",  # medium grey
        "dtype_int": "\033[38;5;39m",  # bright blue
        "dtype_float": "\033[38;5;167m",  # softer red/coral
        # this key was missing from the terminal scheme only; red mirrors the
        # latex scheme's choice for the same key
        "dtype_str": "\033[31m",  # red
        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "torch": "",
        "dtype_bool": "",
        "dtype_int": "",
        "dtype_float": "",
        "dtype_str": "",
        "device_cuda": "",
        "reset": "",
    },
}
OutputFormat = Literal["unicode", "latex", "ascii"]

# One symbol table per output format; every table defines the same keys.
SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": dict(
        range=r"\mathcal{R}",
        mean=r"\mu",
        std=r"\sigma",
        median=r"\tilde{x}",
        distribution=r"\mathbb{P}",
        distribution_log=r"\mathbb{P}_L",
        nan_values=r"\text{NANvals}",
        warning="!!!",
        requires_grad=r"\nabla",
        true=r"\checkmark",
        false=r"\times",
    ),
    "unicode": dict(
        range="R",
        mean="μ",
        std="σ",
        median="x̃",
        distribution="ℙ",
        distribution_log="ℙ˪",
        nan_values="NANvals",
        warning="🚨",
        requires_grad="∇",
        true="✓",
        false="✗",
    ),
    "ascii": dict(
        range="range",
        mean="mean",
        std="std",
        median="med",
        distribution="dist",
        distribution_log="dist_log",
        nan_values="NANvals",
        warning="!!!",
        requires_grad="requires_grad",
        true="1",
        false="0",
    ),
}
"Symbols for different formats"

# Glyphs ordered from lowest to highest bar; index 0 is a blank.
# The "latex" entry holds the same glyphs as "unicode" in its own list object.
SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"],
    "ascii": [" ", "_", ".", "-", "~", "=", "#"],
    "latex": [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"],
}
"characters for sparklines in different formats"
def _to_numpy(A: Any, is_tensor: bool) -> Any:
    """Best-effort conversion of `A` to a numpy array for `array_info`.

    Falls back to an empty array whenever the conversion raises.
    """
    try:
        if not is_tensor:
            # numpy arrays and other array-like objects
            return np.asarray(A)

        # tensors are handled purely through attribute lookups, so torch
        # never has to be imported here
        try:
            on_gpu = bool(getattr(A, "is_cuda", False))
        except Exception:
            on_gpu = False

        # a CUDA tensor is asked for its CPU copy before detaching/converting
        cpu_tensor = getattr(A, "cpu", lambda: A)() if on_gpu else A
        detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
        return getattr(detached, "numpy", lambda: np.array([]))()
    except Exception:
        return np.array([])


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Every step is best-effort: a failing step leaves its fields as `None`
    instead of raising.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
        Number of bins passed to `np.histogram`
        (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        `status` is one of `"ok"`, `"empty array"`, `"all NaN"` or
        `"error: <message>"`; `nan_count` is a python `int` and `has_nans` a
        python `bool` whenever the NaN check succeeds.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    result["is_tensor"] = type(A).__name__ == "Tensor"

    # Try to get device information if it's a tensor
    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to numpy array for calculations
    A_np = _to_numpy(A, result["is_tensor"])

    # Get basic information; the order matters because a failure keeps the
    # fields that were already filled in
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    # TODO: type checks fail on 3.10, see https://github.com/mivanit/muutils/actions/runs/18883100459/job/53891346225
    try:
        A_flat = A_np.flatten() if len(A_np.shape) > 1 else A_np  # type: ignore[assignment]
    except Exception:
        A_flat = A_np  # type: ignore[assignment]

    # Check for NaN values
    nan_mask: Any = None
    try:
        nan_mask = np.isnan(A_flat)
        # plain python scalars, so callers do not receive numpy scalar types
        result["nan_count"] = int(np.sum(nan_mask))
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except Exception:
        # np.isnan rejects some dtypes; the NaN fields then stay None
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            reducers = (np.nanmin, np.nanmax, np.nanmean, np.nanstd, np.nanmedian)
        else:
            reducers = (np.min, np.max, np.mean, np.std, np.median)
        for stat_key, reducer in zip(
            ("min", "max", "mean", "std", "median"), reducers
        ):
            result[stat_key] = float(reducer(A_flat))
        result["range"] = (result["min"], result["max"])

        # Remove NaNs for histogram
        A_hist = A_flat[~nan_mask] if result["has_nans"] else A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                # np.histogram converts bool input to uint8 itself but emits a
                # RuntimeWarning while doing so; converting up front gives the
                # same bins without the warning
                if A_hist.dtype == np.bool_:
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                # a missing histogram is not an error; status stays "ok"
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result
def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram data; any sequence accepted by `np.asarray` works
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format; a value that is not a key of `SPARK_CHARS` falls back
        to the `"ascii"` characters
        (defaults to `"unicode"`)
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used.
        `("", False)` for a `None` or empty histogram
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    # plain lists previously crashed on `.max()` in the linear branch only;
    # converting once makes both branches behave the same
    counts = np.asarray(histogram)

    # Get the appropriate character set
    chars: List[str] = SPARK_CHARS[format] if format in SPARK_CHARS else SPARK_CHARS["ascii"]
    top_level: int = len(chars) - 1

    # automatic detection of log_y
    if log_y is None:
        # we bin the histogram values to the number of levels in our sparkline characters
        level_counts = np.histogram(counts, bins=len(chars))[0]
        # if every bin except the smallest (first) and largest (last) is empty,
        # then we should use the log scale. if those bins are nonempty, keep the linear scale
        log_y = not (level_counts[1:-1].max() > 0)

    # log1p keeps zero counts at zero instead of hitting log(0)
    hist_data = np.log1p(counts) if log_y else counts

    # Normalize to character set range
    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * top_level
    else:
        normalized = np.zeros_like(hist_data)

    # Convert to characters
    spark = "".join(chars[round(val)] for val in normalized)

    return spark, log_y
# Fallback values for the keyword arguments of `array_summary`; an argument
# that is not passed explicitly is looked up here at call time.
DEFAULT_SETTINGS: Dict[str, Any] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": False,
    "sparkline_bins": 5,
    "sparkline_logy": None,
    "colored": False,
    "as_list": False,
    "eq_char": "=",
}
def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color code that `colors` stores under `color_key`.

    # Parameters:
     - `text : str`
        Text to color
     - `color_key : str`
        Key looked up in `colors`; a missing key raises `KeyError`
     - `colors : Dict[str, str]`
        Color scheme; non-LaTeX output also reads its `"reset"` entry
     - `using_tex : bool`
        `True` renders `<code>{<text>}`, `False` renders `<code><text><reset>`

    # Returns:
     - `str`
        The colored text, or `text` unchanged when the color code is empty
    """
    code = colors[color_key]
    if not code:
        # empty code means "no color": nothing is wrapped around the text
        return text
    if using_tex:
        return f"{code}{{{text}}}"
    return f"{code}{text}{colors['reset']}"
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize dtype string with specific colors for torch and type names.

    # Parameters:
     - `dtype_str : str`
        dtype name, optionally containing a `torch.` prefix
     - `colors : Dict[str, str]`
        Color scheme handed on to `apply_color`
     - `using_tex : bool`
        Whether LaTeX-style coloring is used

    # Returns:
     - `str`
        The colored type name, preceded by a separately colored `torch.`
        when `dtype_str` contains exactly one `torch.`
    """
    # The type color is chosen from the full string; the first matching
    # family wins, in the order bool, int, float
    lowered = dtype_str.lower()
    color_key: str = next(
        (f"dtype_{family}" for family in ("bool", "int", "float") if family in lowered),
        "dtype",
    )

    # Splitting on "torch." yields two pieces only for a single occurrence;
    # the text after it is the bare type name
    pieces = dtype_str.split("torch.")
    if len(pieces) != 2:
        return apply_color(dtype_str, color_key, colors, using_tex)

    torch_colored = apply_color("torch", "torch", colors, using_tex)
    name_colored = apply_color(pieces[1], color_key, colors, using_tex)
    return f"{torch_colored}.{name_colored}" if torch_colored else name_colored
def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    # Parameters:
     - `shape_val`
        Sequence of dimension sizes
     - `colors : Dict[str, str]`
        Color scheme; the `"shape"` entry is used for every dimension
     - `using_tex : bool`
        Whether LaTeX-style coloring is used

    # Returns:
     - `str`
        The single colored dimension for a length-1 shape, otherwise the
        colored dimensions joined by `,` inside parentheses
    """
    # Delegates to the module-level `apply_color` instead of carrying a
    # private copy of the same logic
    dims = [apply_color(str(dim), "shape", colors, using_tex) for dim in shape_val]
    if len(dims) == 1:
        # For 1D arrays, still color the dimension value
        return dims[0]
    # For multi-D arrays, color each dimension
    return "(" + ",".join(dims) + ")"
def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    # Parameters:
     - `device_str : str`
        Device description
     - `colors : Dict[str, str]`
        Color scheme; `"device_cuda"` is used when `device_str` contains
        `cuda` (case-insensitive), `"device"` otherwise
     - `using_tex : bool`
        Whether LaTeX-style coloring is used

    # Returns:
     - `str`
        The colored device string
    """
    # Delegates to the module-level `apply_color` instead of carrying a
    # private copy of the same logic
    color_key = "device_cuda" if "cuda" in device_str.lower() else "device"
    return apply_color(device_str, color_key, colors, using_tex)
class _UseDefaultType:
    """Type of the `_USE_DEFAULT` sentinel."""

    pass


# Sentinel default for the keyword arguments of `array_summary`: a parameter
# left at this value is replaced by the matching `DEFAULT_SETTINGS` entry at
# call time. A dedicated object is used rather than `None` because `None` is
# itself a meaningful value for `sparkline_logy`.
_USE_DEFAULT = _UseDefaultType()
# The overloads describe only how `as_list` selects the return type. It is
# keyword-only here because the second positional parameter of the real
# signature is `fmt`, and it is optional in the `str` overload so that a call
# without `as_list` type-checks.
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[True],
    **kwargs: Any,
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[False] = ...,
    **kwargs: Any,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: Optional[bool] = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Every keyword argument that is not passed explicitly takes its value from
    `DEFAULT_SETTINGS` at call time; the defaults listed below are the values
    `DEFAULT_SETTINGS` is initialized with.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
     - `precision : int`
        Decimal places (defaults to `2`)
     - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
     - `shape : bool`
        Whether to include shape info (defaults to `True`)
     - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
     - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
        Number of histogram bins, i.e. characters in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline, `None` for automatic
        detection (defaults to `None`)
     - `colored : bool`
        Whether to add color to output (defaults to `False`)
     - `eq_char : str`
        Separator placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """

    def _resolve(value: Any, key: str) -> Any:
        # only the sentinel means "not given"; None and False are real values
        return DEFAULT_SETTINGS[key] if value is _USE_DEFAULT else value

    fmt = _resolve(fmt, "fmt")
    precision = _resolve(precision, "precision")
    stats = _resolve(stats, "stats")
    shape = _resolve(shape, "shape")
    dtype = _resolve(dtype, "dtype")
    device = _resolve(device, "device")
    requires_grad = _resolve(requires_grad, "requires_grad")
    sparkline = _resolve(sparkline, "sparkline")
    sparkline_bins = _resolve(sparkline_bins, "sparkline_bins")
    sparkline_logy = _resolve(sparkline_logy, "sparkline_logy")
    colored = _resolve(colored, "colored")
    as_list = _resolve(as_list, "as_list")
    eq_char = _resolve(eq_char, "eq_char")

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text with the scheme selected above
    def colorize(text: str, color_key: str) -> str:
        return apply_color(text, color_key, colors, using_tex)

    # Check if dtype is integer type
    # `array_info` always provides the "dtype" key but leaves it None when the
    # dtype could not be read, so fall back to "" explicitly
    dtype_str: str = array_data.get("dtype") or ""
    dtype_lower: str = dtype_str.lower()
    # "int" also matches every "uint*" name
    is_int_dtype: bool = "int" in dtype_lower or "bool" in dtype_lower

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = (
                f"{symbols['warning']} {symbols['nan_values']}{eq_char}"
                f"{array_data['nan_count']} "
                f"({array_data['nan_percent']:.1f}{_percent})"
            )
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether the log scale was used, which
        # decides the symbol shown in front of the sparkline
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)