Coverage for muutils / tensor_info.py: 89%
266 statements
coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
1"get metadata about a tensor, mostly for `muutils.dbg`"
3from __future__ import annotations
5import numpy as np
6from typing import Union, Any, Literal, List, Dict, overload, Optional, TYPE_CHECKING
8if TYPE_CHECKING:
9 from typing import TypedDict
10else:
11 try:
12 from typing import TypedDict
13 except ImportError:
14 from typing_extensions import TypedDict
# Global color definitions
# Maps a color scheme name ("latex", "terminal", "none") to a mapping from a
# semantic key (e.g. "mean", "shape", "dtype_int") to the prefix that switches
# that color on. "reset" is the suffix appended after colored terminal text.
# All three schemes expose the same set of keys so a lookup never depends on
# which scheme is active.
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "torch": r"\textcolor{orange}",
        "dtype_bool": r"\textcolor{gray}",
        "dtype_int": r"\textcolor{blue}",
        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
        "dtype_str": r"\textcolor{red}",
        "device_cuda": r"\textcolor{green}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "torch": "\033[38;5;208m",  # bright orange
        "dtype_bool": "\033[38;5;245m",  # medium grey
        "dtype_int": "\033[38;5;39m",  # bright blue
        "dtype_float": "\033[38;5;167m",  # softer red/coral
        # red, mirroring the latex scheme; this key was previously missing
        # from the terminal scheme only
        "dtype_str": "\033[31m",
        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "torch": "",
        "dtype_bool": "",
        "dtype_int": "",
        "dtype_float": "",
        "dtype_str": "",
        "device_cuda": "",
        "reset": "",
    },
}
OutputFormat = Literal["unicode", "latex", "ascii"]


class ArraySummarySettings(TypedDict):
    """Type definition for array_summary default settings.

    One entry per keyword argument of `array_summary`; the value stored under
    each key is used whenever the caller leaves that argument unspecified.
    """

    # symbol set / markup flavour of the summary
    fmt: OutputFormat
    # decimal places used when formatting floating point statistics
    precision: int
    # include mean / std / median / range
    stats: bool
    # include the array shape
    shape: bool
    # include the dtype
    dtype: bool
    # include the device (torch tensors only)
    device: bool
    # include the requires_grad marker (torch tensors only)
    requires_grad: bool
    # include a sparkline of the value histogram
    sparkline: bool
    # number of histogram bins (one sparkline character per bin)
    sparkline_bins: int
    # force (True/False) or auto-detect (None) a logarithmic sparkline y-scale
    sparkline_logy: Optional[bool]
    # add color codes to the output
    colored: bool
    # return the parts as a list instead of a single joined string
    as_list: bool
    # character(s) placed between a label and its value
    eq_char: str
# Per output format: the label printed in front of each statistic, plus the
# markers used for warnings and booleans. Every format defines the same keys.
SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "distribution_log": r"\mathbb{P}_L",
        "nan_values": r"\text{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
        "true": r"\checkmark",
        "false": r"\times",
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "distribution_log": "ℙ˪",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
        "true": "✓",
        "false": "✗",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "distribution_log": "dist_log",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
        "true": "1",
        "false": "0",
    },
}
"Symbols for different formats"
# Ordered from "empty" to "full": index 0 is drawn for the smallest bin height,
# the last index for the largest.
SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
}
"characters for sparklines in different formats"
def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Every step is best-effort: a failure in one step leaves the corresponding
    entries as `None` instead of raising.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
        Number of bins passed to `np.histogram`
        (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        Keys: `is_tensor`, `device`, `requires_grad`, `shape`, `dtype`, `size`,
        `has_nans`, `nan_count`, `nan_percent`, `min`, `max`, `range`, `mean`,
        `std`, `median`, `histogram`, `bins`, `status`.
        `status` is one of `"ok"`, `"empty array"`, `"all NaN"` or `"error: <message>"`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    result["is_tensor"] = type(A).__name__ == "Tensor"

    # Try to get device information if it's a tensor
    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to numpy array for calculations; any failure along the way
    # degrades to an empty array, which is reported as "empty array" below
    try:
        if result["is_tensor"]:
            try:
                is_cuda: bool = bool(getattr(A, "is_cuda", False))
            except Exception:
                is_cuda = False
            # CUDA tensors are moved to the CPU first, then detached and converted
            cpu_tensor = getattr(A, "cpu", lambda: A)() if is_cuda else A
            detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
            A_np = getattr(detached, "numpy", lambda: np.array([]))()
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        # report the tensor's own dtype (e.g. "torch.float32") rather than numpy's
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    try:
        A_flat = A_np.flatten() if len(A_np.shape) > 1 else A_np
    except Exception:
        A_flat = A_np

    # Check for NaN values (np.isnan raises for dtypes without a NaN concept,
    # e.g. strings; in that case the NaN fields simply stay None)
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        result_size: int = result["size"]  # ty: ignore[invalid-assignment]
        if result_size > 0:
            result["nan_percent"] = (result["nan_count"] / result_size) * 100
    except Exception:
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        has_nans: bool = bool(result["has_nans"])
        # use the NaN-ignoring reductions only when NaNs are actually present
        f_min, f_max, f_mean, f_std, f_median = (
            (np.nanmin, np.nanmax, np.nanmean, np.nanstd, np.nanmedian)
            if has_nans
            else (np.min, np.max, np.mean, np.std, np.median)
        )
        result["min"] = float(f_min(A_flat))
        result["max"] = float(f_max(A_flat))
        result["mean"] = float(f_mean(A_flat))
        result["std"] = float(f_std(A_flat))
        result["median"] = float(f_median(A_flat))
        result["range"] = (result["min"], result["max"])

        if has_nans:
            # Remove NaNs for histogram
            # nan_mask is always bound here: has_nans can only be truthy if the
            # NaN check above succeeded
            A_hist = A_flat[~nan_mask]  # pyright: ignore[reportOperatorIssue, reportPossiblyUnboundVariable]
        else:
            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                # np.histogram converts bool input to uint8 itself but emits a
                # RuntimeWarning while doing so; do the same cast up front so
                # the result is identical and no warning is raised
                if A_hist.dtype == np.bool_:
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result
SparklineFormat = Literal["unicode", "latex", "ascii"]


def generate_sparkline(
    histogram: np.ndarray,
    format: SparklineFormat = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram data (bin counts); any array-like is accepted
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format; an unknown format falls back to the `"ascii"` characters
        (defaults to `"unicode"`)
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization (one character per bin) and whether log scale
        was used. `("", False)` for a missing or empty histogram.
    """
    if histogram is None:
        return "", False
    # accept lists/tuples as well: the array methods below need an ndarray
    counts = np.asarray(histogram)
    if counts.size == 0:
        return "", False

    # Get the appropriate character set
    chars: List[str] = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])
    top_level: int = len(chars) - 1

    # automatic detection of log_y
    if log_y is None:
        # we bin the histogram values to the number of levels in our sparkline characters
        hist_hist = np.histogram(counts, bins=len(chars))[0]
        # if every bin except the smallest (first) and largest (last) is empty,
        # then we should use the log scale. if those bins are nonempty, keep the linear scale
        log_y = not bool(hist_hist[1:-1].max() > 0)

    # log1p rather than log so that empty bins (count 0) map to 0
    hist_data = np.log1p(counts) if log_y else counts

    # Normalize to character set range
    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * top_level
    else:
        normalized = np.zeros_like(hist_data)

    # Convert to characters; int() guarantees a plain-int index regardless of
    # what round() returns for numpy scalars
    spark: str = "".join(chars[int(round(val))] for val in normalized)

    return spark, log_y
# Fallback values used by `array_summary` for every argument the caller does
# not pass explicitly. The dict is read at call time, so modifying an entry
# changes the default for subsequent calls.
DEFAULT_SETTINGS: ArraySummarySettings = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": False,
    "sparkline_bins": 5,
    "sparkline_logy": None,
    "colored": False,
    "as_list": False,
    "eq_char": "=",
}
def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color stored under `color_key`.

    # Parameters:
     - `text : str`
        Text to colorize
     - `color_key : str`
        Key into `colors` selecting the color prefix
     - `colors : Dict[str, str]`
        Color scheme; must contain `color_key`, and `"reset"` when a
        non-empty color is applied outside of LaTeX
     - `using_tex : bool`
        `True` emits `<prefix>{text}`, `False` emits `<prefix>text<reset>`

    # Returns:
     - `str`
        The colorized text, or `text` unchanged if the color prefix is empty

    # Raises:
     - `KeyError` : if `color_key` is not in `colors`
    """
    prefix: str = colors[color_key]
    if not prefix:
        # an empty prefix means "no color": do not add braces or a reset code
        return text
    if using_tex:
        return f"{prefix}{{{text}}}"
    return f"{prefix}{text}{colors['reset']}"
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize dtype string with specific colors for torch and type names.

    # Parameters:
     - `dtype_str : str`
        dtype name, e.g. `"float64"` or `"torch.int32"`
     - `colors : Dict[str, str]`
        Color scheme; read keys are `"torch"`, `"dtype"`, `"dtype_bool"`,
        `"dtype_int"`, `"dtype_float"` (and `"reset"` via `apply_color`)
     - `using_tex : bool`
        Whether to emit LaTeX color markup

    # Returns:
     - `str`
        The colorized dtype; a `torch.` prefix is colored separately from the
        type name and re-joined with a `.`
    """
    # Handle torch prefix: only split when "torch." occurs exactly once
    type_part: str = dtype_str
    prefix_part: Optional[str] = None
    if "torch." in dtype_str:
        parts = dtype_str.split("torch.")
        if len(parts) == 2:
            prefix_part = apply_color("torch", "torch", colors, using_tex)
            type_part = parts[1]

    # Handle type coloring; the checks are ordered, and the first match wins
    lowered: str = dtype_str.lower()
    color_key: str = "dtype"
    if "bool" in lowered:
        color_key = "dtype_bool"
    elif "int" in lowered:
        color_key = "dtype_int"
    elif "float" in lowered:
        color_key = "dtype_float"

    type_colored: str = apply_color(type_part, color_key, colors, using_tex)

    if prefix_part:
        return f"{prefix_part}.{type_colored}"
    return type_colored
def format_shape_colored(
    shape_val: Any, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    # Parameters:
     - `shape_val : Any`
        Sized iterable of dimensions, e.g. `(3, 4)`
     - `colors : Dict[str, str]`
        Color scheme; read keys are `"shape"` and, for non-empty terminal
        colors, `"reset"`
     - `using_tex : bool`
        Whether to emit LaTeX color markup

    # Returns:
     - `str`
        The bare dimension for a 1D shape (`"3"`), otherwise the dimensions
        comma-joined inside parentheses (`"(3,4)"`); each dimension is colored
        individually, and the parentheses and commas stay uncolored
    """
    shape_color: str = colors["shape"]

    # named differently from the module-level `apply_color`, which takes
    # different arguments
    def _paint(dim: Any) -> str:
        text = str(dim)
        if not shape_color:
            return text
        if using_tex:
            return f"{shape_color}{{{text}}}"
        return f"{shape_color}{text}{colors['reset']}"

    if len(shape_val) == 1:
        # For 1D arrays, still color the dimension value
        return _paint(shape_val[0])
    # For multi-D arrays, color each dimension
    return "(" + ",".join(_paint(dim) for dim in shape_val) + ")"
def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    # Parameters:
     - `device_str : str`
        Device name, e.g. `"cpu"` or `"cuda:0"`
     - `colors : Dict[str, str]`
        Color scheme; reads `"device_cuda"` when `device_str` contains
        `"cuda"` (case-insensitive), otherwise `"device"`, plus `"reset"` for
        non-empty terminal colors
     - `using_tex : bool`
        Whether to emit LaTeX color markup

    # Returns:
     - `str`
        The colorized device string
    """
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    prefix: str = colors[color_key]
    if not prefix:
        # empty prefix: color disabled for this key
        return device_str
    if using_tex:
        return f"{prefix}{{{device_str}}}"
    return f"{prefix}{device_str}{colors['reset']}"
class _UseDefaultType:
    """Type of the `_USE_DEFAULT` sentinel.

    Used as the default value of `array_summary` arguments to mean "not passed,
    look the value up in `DEFAULT_SETTINGS`". A dedicated type is needed because
    `None` is itself a valid value for some arguments (e.g. `sparkline_logy`).
    """

    def __repr__(self) -> str:
        # readable in signatures and debug output instead of an object address
        return "<USE_DEFAULT>"


_USE_DEFAULT = _UseDefaultType()
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    *,
    as_list: Literal[True],
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    as_list: Literal[False] = ...,
) -> str: ...
@overload
def array_summary(
    array: Any,
    fmt: OutputFormat = ...,
    precision: int = ...,
    stats: bool = ...,
    shape: bool = ...,
    dtype: bool = ...,
    device: bool = ...,
    requires_grad: bool = ...,
    sparkline: bool = ...,
    sparkline_bins: int = ...,
    sparkline_logy: Optional[bool] = ...,
    colored: bool = ...,
    eq_char: str = ...,
    as_list: bool = ...,
) -> Union[str, List[str]]: ...
def array_summary(
    array: Any,
    fmt: Union[OutputFormat, _UseDefaultType] = _USE_DEFAULT,
    precision: Union[int, _UseDefaultType] = _USE_DEFAULT,
    stats: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    shape: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    dtype: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    device: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    requires_grad: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    sparkline_bins: Union[int, _UseDefaultType] = _USE_DEFAULT,
    sparkline_logy: Union[Optional[bool], _UseDefaultType] = _USE_DEFAULT,
    colored: Union[bool, _UseDefaultType] = _USE_DEFAULT,
    eq_char: Union[str, _UseDefaultType] = _USE_DEFAULT,
    as_list: Union[bool, _UseDefaultType] = _USE_DEFAULT,
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any argument left unspecified takes its value from `DEFAULT_SETTINGS` at
    call time; the defaults quoted below are the initial contents of that dict.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
     - `precision : int`
        Decimal places (defaults to `2`)
     - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
     - `shape : bool`
        Whether to include shape info (defaults to `True`)
     - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
     - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
        Number of histogram bins, i.e. characters in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline, `None` for automatic
        detection (defaults to `None`)
     - `colored : bool`
        Whether to add color to output (defaults to `False`)
     - `eq_char : str`
        Separator placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """
    # Resolve every unspecified argument against the module-level defaults
    if isinstance(fmt, _UseDefaultType):
        fmt = DEFAULT_SETTINGS["fmt"]
    if isinstance(precision, _UseDefaultType):
        precision = DEFAULT_SETTINGS["precision"]
    if isinstance(stats, _UseDefaultType):
        stats = DEFAULT_SETTINGS["stats"]
    if isinstance(shape, _UseDefaultType):
        shape = DEFAULT_SETTINGS["shape"]
    if isinstance(dtype, _UseDefaultType):
        dtype = DEFAULT_SETTINGS["dtype"]
    if isinstance(device, _UseDefaultType):
        device = DEFAULT_SETTINGS["device"]
    if isinstance(requires_grad, _UseDefaultType):
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if isinstance(sparkline, _UseDefaultType):
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if isinstance(sparkline_bins, _UseDefaultType):
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if isinstance(sparkline_logy, _UseDefaultType):
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if isinstance(colored, _UseDefaultType):
        colored = DEFAULT_SETTINGS["colored"]
    if isinstance(as_list, _UseDefaultType):
        as_list = DEFAULT_SETTINGS["as_list"]
    if isinstance(eq_char, _UseDefaultType):
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text with the active scheme
    def colorize(text: str, color_key: str) -> str:
        return apply_color(text, color_key, colors, using_tex)

    # Check if dtype is integer type
    # ("dtype" is always present in array_data but is None when it could not
    # be determined, so fall back to "" explicitly)
    dtype_str: str = array_data.get("dtype") or ""
    dtype_lower: str = dtype_str.lower()
    is_int_dtype: bool = any(
        int_type in dtype_lower for int_type in ["int", "uint", "bool"]
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle empty / all-NaN arrays: only a warning is shown instead of stats
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(f"{symbols['warning']} {status}", "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # Range (min, max); integer-like dtypes are printed without decimals
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether a log scale was used, which
        # selects the label symbol
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested (an empty shape tuple is skipped)
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)