Coverage for muutils/tensor_info.py: 90%
250 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-07 20:16 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-07 20:16 -0700
1"get metadata about a tensor, mostly for `muutils.dbg`"
3from __future__ import annotations
5import numpy as np
6from typing import Union, Any, Literal, List, Dict, overload, Optional
# Global color definitions.
# Maps a color scheme name ("latex", "terminal", "none") to a table of
# color-key -> markup prefix. Every scheme defines the same set of keys so a
# key lookup can never succeed for one scheme and fail for another.
COLORS: Dict[str, Dict[str, str]] = {
    # LaTeX: each value is a `\textcolor{...}` command; no reset is needed
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "torch": r"\textcolor{orange}",
        "dtype_bool": r"\textcolor{gray}",
        "dtype_int": r"\textcolor{blue}",
        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
        "dtype_str": r"\textcolor{red}",
        "device_cuda": r"\textcolor{green}",
        "reset": "",
    },
    # Terminal: ANSI escape sequences, closed by the "reset" sequence
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "torch": "\033[38;5;208m",  # bright orange
        "dtype_bool": "\033[38;5;245m",  # medium grey
        "dtype_int": "\033[38;5;39m",  # bright blue
        "dtype_float": "\033[38;5;167m",  # softer red/coral
        # previously missing here although "latex" and "none" define it
        "dtype_str": "\033[31m",  # red, mirroring the latex scheme
        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
        "reset": "\033[0m",
    },
    # No coloring: every key maps to the empty string
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "torch": "",
        "dtype_bool": "",
        "dtype_int": "",
        "dtype_float": "",
        "dtype_str": "",
        "device_cuda": "",
        "reset": "",
    },
}
OutputFormat = Literal["unicode", "latex", "ascii"]

SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    # math-mode commands for LaTeX output
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "distribution_log": r"\mathbb{P}_L",
        "nan_values": r"\text{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
        "true": r"\checkmark",
        "false": r"\times",
    },
    # compact glyphs for unicode-capable output
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "distribution_log": "ℙ˪",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
        "true": "✓",
        "false": "✗",
    },
    # plain-text names for ASCII-only output
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "distribution_log": "dist_log",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
        "true": "1",
        "false": "0",
    },
}
"Symbols for different formats"

# One character per sparkline level, ordered from lowest to highest.
# Each format gets its own list object.
SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
}
"characters for sparklines in different formats"
def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
        Number of bins passed to `np.histogram` for the `"histogram"` and
        `"bins"` entries
        (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        `"status"` is `"ok"`, `"empty array"`, `"all NaN"` or
        `"error: <message>"`; entries that could not be computed stay `None`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    result["is_tensor"] = type(A).__name__ == "Tensor"

    # Try to get device information if it's a tensor
    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to numpy array for calculations; any failure along the way
    # degrades to an empty array, which is reported as "empty array" below
    A_np: np.ndarray = np.array([])
    try:
        if result["is_tensor"]:
            source: Any = A
            try:
                on_gpu = bool(getattr(A, "is_cuda", False))
            except Exception:
                on_gpu = False
            if on_gpu:
                # move to host memory first; `.numpy()` needs a CPU tensor
                source = getattr(A, "cpu", lambda: A)()
            detached = getattr(source, "detach", lambda: source)()
            A_np = getattr(detached, "numpy", lambda: np.array([]))()
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        # for tensors keep the tensor's own dtype string
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics
    try:
        A_flat = A_np.ravel()
    except Exception:
        A_flat = A_np

    # Check for NaN values (np.isnan raises for non-numeric dtypes, in which
    # case the NaN fields simply stay None)
    nan_mask: Any = None
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"]:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except Exception:
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            # NaN-ignoring reductions; NaNs are also dropped for the histogram
            f_min, f_max, f_mean, f_std, f_median = (
                np.nanmin,
                np.nanmax,
                np.nanmean,
                np.nanstd,
                np.nanmedian,
            )
            A_hist = A_flat[~nan_mask]
        else:
            f_min, f_max, f_mean, f_std, f_median = (
                np.min,
                np.max,
                np.mean,
                np.std,
                np.median,
            )
            A_hist = A_flat

        result["min"] = float(f_min(A_flat))
        result["max"] = float(f_max(A_flat))
        result["mean"] = float(f_mean(A_flat))
        result["std"] = float(f_std(A_flat))
        result["median"] = float(f_median(A_flat))
        result["range"] = (result["min"], result["max"])

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                # np.histogram converts bool input to uint8 itself but emits a
                # RuntimeWarning while doing so; convert up front to stay quiet
                if A_hist.dtype == np.bool_:
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                # e.g. non-finite values; the summary works without a histogram
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result
def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram data (bin counts); any 1D sequence of numbers is accepted
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format; an unrecognized value falls back to the `"ascii"`
        character set
        (defaults to `"unicode"`)
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used.
        `("", False)` when `histogram` is `None` or empty.
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    # accept plain sequences as well as arrays
    counts = np.asarray(histogram)

    # Get the appropriate character set
    chars: List[str] = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])

    # automatic detection of log_y
    if log_y is None:
        # we bin the histogram values to the number of levels in our sparkline characters
        level_counts = np.histogram(counts, bins=len(chars))[0]
        # if every bin except the smallest (first) and largest (last) is empty,
        # then we should use the log scale. if those bins are nonempty, keep the linear scale
        log_y = not (level_counts[1:-1].max() > 0)

    # log1p keeps empty bins at 0 and avoids log(0)
    hist_data = np.log1p(counts) if log_y else counts

    # Normalize to character set range
    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # Convert to characters; round() picks the nearest level
    spark: str = "".join(chars[int(round(float(level)))] for level in normalized)

    return spark, bool(log_y)
# Fallback values for `array_summary` options that the caller leaves unset.
DEFAULT_SETTINGS: Dict[str, Any] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": False,
    "sparkline_bins": 5,
    "sparkline_logy": None,
    "colored": False,
    "as_list": False,
    "eq_char": "=",
}
def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color markup stored under `color_key`.

    # Parameters:
     - `text : str`
        Text to colorize
     - `color_key : str`
        Key into `colors` selecting the color
     - `colors : Dict[str, str]`
        Color table; must contain `color_key`, and `"reset"` when a
        non-empty color is applied with `using_tex=False`
     - `using_tex : bool`
        `True` to emit `<color>{text}` (LaTeX brace group), `False` to emit
        `<color>text<reset>`

    # Returns:
     - `str`
        The colorized text, or `text` unchanged when the color is empty
    """
    prefix: str = colors[color_key]
    if not prefix:
        return text
    if using_tex:
        # `{{` / `}}` are literal braces; a spaced `{ {text}}` would instead
        # format a set literal and print `{'text'}`
        return f"{prefix}{{{text}}}"
    return f"{prefix}{text}{colors['reset']}"
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Color a dtype string, treating a `torch.` prefix and the type name separately.

    The type color is chosen from the first of `bool`, `int`, `float` found in
    the lowercased dtype string (`dtype_bool`, `dtype_int`, `dtype_float`),
    falling back to the generic `dtype` color. When the string contains
    `torch.` exactly once, the word `torch` gets the `torch` color and only the
    text after the prefix gets the type color.
    """
    lowered: str = dtype_str.lower()
    # order matters: the first matching family wins
    families = (
        ("bool", "dtype_bool"),
        ("int", "dtype_int"),
        ("float", "dtype_float"),
    )
    type_color: str = next(
        (key for needle, key in families if needle in lowered),
        "dtype",
    )

    # exactly two pieces <=> exactly one "torch." in the string
    pieces = dtype_str.split("torch.")
    if len(pieces) == 2:
        torch_colored: str = apply_color("torch", "torch", colors, using_tex)
        name_colored: str = apply_color(pieces[1], type_color, colors, using_tex)
        return f"{torch_colored}.{name_colored}"

    return apply_color(dtype_str, type_color, colors, using_tex)
def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    # Parameters:
     - `shape_val`
        Sequence of dimension sizes
     - `colors : Dict[str, str]`
        Color table; the `"shape"` entry is used (plus `"reset"` for
        non-LaTeX output when the color is non-empty)
     - `using_tex : bool`
        Whether to emit LaTeX markup instead of prefix/reset sequences

    # Returns:
     - `str`
        The single colored dimension for a length-1 shape, otherwise the
        colored dimensions joined by commas inside parentheses
    """
    prefix: str = colors["shape"]

    def paint(text: str) -> str:
        if not prefix:
            return text
        if using_tex:
            # doubled braces are literal braces: `<color>{text}`
            return f"{prefix}{{{text}}}"
        return f"{prefix}{text}{colors['reset']}"

    if len(shape_val) == 1:
        # For 1D arrays, still color the dimension value
        return paint(str(shape_val[0]))

    # For multi-D arrays, color each dimension
    return "(" + ",".join(paint(str(dim)) for dim in shape_val) + ")"
def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    # Parameters:
     - `device_str : str`
        Device description; any string containing `cuda` (case-insensitive)
        is colored with `"device_cuda"`, everything else with `"device"`
     - `colors : Dict[str, str]`
        Color table providing `"device"`, `"device_cuda"` and `"reset"`
     - `using_tex : bool`
        Whether to emit LaTeX markup instead of prefix/reset sequences

    # Returns:
     - `str`
        The colored device string, or `device_str` unchanged when the
        selected color is empty
    """
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    prefix: str = colors[color_key]
    if not prefix:
        return device_str
    if using_tex:
        # doubled braces are literal braces: `<color>{device}`
        return f"{prefix}{{{device_str}}}"
    return f"{prefix}{device_str}{colors['reset']}"
class _UseDefaultType:
    """Type of the `_USE_DEFAULT` sentinel, which marks an argument as "not supplied"."""


# Compared by identity (`is _USE_DEFAULT`); lets `None` remain a usable value.
_USE_DEFAULT = _UseDefaultType()
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[True],
    **kwargs: Any,
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[False] = ...,
    **kwargs: Any,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: Optional[bool] = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Every option left unset falls back to the matching entry of
    `DEFAULT_SETTINGS`, looked up at call time.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format; selects the symbol set and, with `colored`, the color scheme
     - `precision : int`
        Decimal places for floating point values
     - `stats : bool`
        Whether to include statistical info (mean, std, median, range)
     - `shape : bool`
        Whether to include shape info
     - `dtype : bool`
        Whether to include dtype info
     - `device : bool`
        Whether to include device info for torch tensors
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors
     - `sparkline : bool`
        Whether to include a sparkline visualization
     - `sparkline_bins : int`
        Number of histogram bins, i.e. characters, in the sparkline
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline, `None` to auto-detect
     - `colored : bool`
        Whether to add color to output
     - `eq_char : str`
        String placed between each label and its value
     - `as_list : bool`
        Whether to return as list of strings instead of joined string

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """

    def _setting(value: Any, key: str) -> Any:
        # an explicitly passed value wins, otherwise use the module default
        return DEFAULT_SETTINGS[key] if value is _USE_DEFAULT else value

    fmt = _setting(fmt, "fmt")
    precision = _setting(precision, "precision")
    stats = _setting(stats, "stats")
    shape = _setting(shape, "shape")
    dtype = _setting(dtype, "dtype")
    device = _setting(device, "device")
    requires_grad = _setting(requires_grad, "requires_grad")
    sparkline = _setting(sparkline, "sparkline")
    sparkline_bins = _setting(sparkline_bins, "sparkline_bins")
    sparkline_logy = _setting(sparkline_logy, "sparkline_logy")
    colored = _setting(colored, "colored")
    as_list = _setting(as_list, "as_list")
    eq_char = _setting(eq_char, "eq_char")

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text
    def colorize(text: str, color_key: str) -> str:
        prefix: str = colors[color_key]
        if not prefix:
            return text
        if using_tex:
            # doubled braces are literal braces: `<color>{text}`
            return f"{prefix}{{{text}}}"
        return f"{prefix}{text}{colors['reset']}"

    # Check if dtype is integer type ("dtype" may be present but None)
    dtype_str: str = array_data.get("dtype") or ""
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ["int", "uint", "bool"]
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Empty and all-NaN arrays get a warning instead of statistics.
    # An "error: ..." status is not matched here and takes the else branch.
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    # integer-like dtypes are shown without decimals
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether the log scale was used,
        # which decides the symbol shown in front of the sparkline
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested (an empty shape tuple is skipped)
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)