Coverage for tests/unit/test_tensor_info.py: 91%
95 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:39 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:39 -0600
1from __future__ import annotations
3from pathlib import Path
4from typing import Any, List, Tuple
6import numpy as np
7import pytest
8import torch
10import muutils.tensor_info as tensor_info
11from muutils.tensor_info import array_summary, generate_sparkline
# Directory (relative to the working directory) where tests write scratch output.
TEMP_PATH: Path = Path("tests/_temp")

# Probe once, at import time, whether this torch build accepts the "meta" device.
# Any failure while allocating on it is treated as "not supported".
meta_supported: bool = False
if torch is not None:
    try:
        torch.empty(1, device="meta")
    except Exception:
        meta_supported = False
    else:
        meta_supported = True
# Helper function to generate an input based on type and a flag for tensor's requires_grad.
def generate_input(input_type: str, tensor_requires_grad: bool) -> Any:
    """
    Generate an input array or tensor according to input_type.

    Parameters:
     - `input_type : str`
        Must be one of:
        "numpy_normal", "numpy_with_nan", "numpy_empty",
        "torch_cpu", "torch_cpu_nan", "torch_meta", "torch_meta_nan"
     - `tensor_requires_grad : bool`
        For torch arrays, set requires_grad accordingly (ignored for numpy).

    Returns:
     - Array-like input: a 1-D numpy array or a 1-D float torch tensor.

    Raises:
     - `ValueError` : if `input_type` is an unrecognised "numpy..." name, or
        starts with neither "numpy" nor "torch".
    """
    if input_type.startswith("numpy"):
        if input_type == "numpy_normal":
            return np.array([1, 2, 3, 4, 5])
        elif input_type == "numpy_with_nan":
            return np.array([np.nan, 1, 2, np.nan, 3])
        elif input_type == "numpy_empty":
            return np.array([])
        else:
            raise ValueError("Unknown numpy input type")
    elif torch is not None and input_type.startswith("torch"):
        # "cpu" takes precedence; anything that names neither device also lands on cpu.
        if "cpu" not in input_type and "meta" in input_type:
            device_str: str = "meta"
        else:
            device_str = "cpu"
        # The torch type names end in "_nan" (e.g. "torch_cpu_nan"), so testing for
        # the substring "with_nan" never matched and the NaN variants were silently
        # built from NaN-free data. Matching on "nan" covers both spellings.
        if "nan" in input_type:
            data = [float("nan"), 1.0, 2.0, float("nan"), 3.0]
        else:
            data = [1.0, 2.0, 3.0, 4.0, 5.0]
        return torch.tensor(data, device=device_str, requires_grad=tensor_requires_grad)
    else:
        raise ValueError("Unknown input type or torch not available")
# Keyword-option sets fed to array_summary by the parametrized test.
# "call_requires_grad" is passed as the `requires_grad` argument of the call.
option_dicts: List[dict] = [
    # Unicode output with every informational field on and the sparkline off.
    dict(
        fmt="unicode",
        sparkline=False,
        colored=False,
        as_list=False,
        stats=True,
        shape=True,
        dtype=True,
        device=True,
        call_requires_grad=True,
        eq_char="=",
        sparkline_bins=5,
        sparkline_logy=False,
    ),
    # LaTeX output with the extra info off; sparkline on with 10 bins and log scale.
    dict(
        fmt="latex",
        sparkline=True,
        colored=True,
        as_list=True,
        stats=False,
        shape=False,
        dtype=False,
        device=False,
        call_requires_grad=False,
        eq_char=":",
        sparkline_bins=10,
        sparkline_logy=True,
    ),
    # ASCII output with a mix of fields on and off.
    dict(
        fmt="ascii",
        sparkline=True,
        colored=False,
        as_list=False,
        stats=True,
        shape=False,
        dtype=True,
        device=False,
        call_requires_grad=True,
        eq_char="=",
        sparkline_bins=5,
        sparkline_logy=True,
    ),
    # Unicode output with all features on, but no gradient info requested in the call.
    dict(
        fmt="unicode",
        sparkline=True,
        colored=True,
        as_list=True,
        stats=True,
        shape=True,
        dtype=True,
        device=True,
        call_requires_grad=False,
        eq_char=":",
        sparkline_bins=10,
        sparkline_logy=False,
    ),
]
# (input_type, tensor_requires_grad) pairs used to parametrize the tests.
# numpy entries always carry False: the flag is irrelevant for numpy inputs.
input_params: List[Tuple[str, bool]] = [
    (numpy_kind, False)
    for numpy_kind in ("numpy_normal", "numpy_with_nan", "numpy_empty")
]

if torch is not None:
    # Torch CPU inputs: each kind once with grad and once without.
    input_params += [
        (cpu_kind, grad_flag)
        for cpu_kind in ("torch_cpu", "torch_cpu_nan")
        for grad_flag in (True, False)
    ]
    if meta_supported:
        # Same grid on the "meta" device, only when the import-time probe succeeded.
        input_params += [
            (meta_kind, grad_flag)
            for meta_kind in ("torch_meta", "torch_meta_nan")
            for grad_flag in (True, False)
        ]
@pytest.mark.parametrize(
    "options", option_dicts
)  # , ids=lambda opt: f'opts_{opt["fmt"]}_spark{opt["sparkline"]}_col{opt["colored"]}')
@pytest.mark.parametrize(
    "input_type,tensor_requires_grad", input_params
)  # , ids=lambda p: f'{p[0]}_grad{p[1]}')
def test_array_summary_comprehensive(
    input_type: str, tensor_requires_grad: bool, options: dict
) -> None:
    """
    Comprehensive test for array_summary.

    Runs once per combination of an entry in `input_params` and an entry in
    `option_dicts`. For each combination the input is built with
    `generate_input`, summarised with `array_summary`, and the resulting string
    (a list result is joined with spaces) is appended, together with a one-line
    explanation, to two files under `TEMP_PATH / "tensor_info"`.
    The summary string is then checked for expected substrings based on the
    input type and option settings.

    Parameters:
     - `input_type : str`
        Name understood by `generate_input`.
     - `tensor_requires_grad : bool`
        Forwarded to `generate_input`.
     - `options : dict`
        One entry of `option_dicts`; its "call_requires_grad" value is passed
        as the `requires_grad` argument of `array_summary`.
    """
    # Generate the input.
    arr: Any = generate_input(input_type, tensor_requires_grad)

    # Call array_summary with the options.
    summary: Any = array_summary(  # type: ignore[call-overload]
        arr,
        fmt=options["fmt"],
        sparkline=options["sparkline"],
        colored=options["colored"],
        as_list=options["as_list"],
        stats=options["stats"],
        shape=options["shape"],
        dtype=options["dtype"],
        device=options["device"],
        requires_grad=options["call_requires_grad"],
        eq_char=options["eq_char"],
        sparkline_bins=options["sparkline_bins"],
        sparkline_logy=options["sparkline_logy"],
    )

    print(f"{arr = }")
    print(f"{options = }")
    print(f"{summary = }")

    # If as_list is True, join to a string for checking.
    summary_str: str = summary if isinstance(summary, str) else " ".join(summary)

    # Write explanation and summary to output files.
    output_dir = TEMP_PATH / "tensor_info"
    output_dir.mkdir(parents=True, exist_ok=True)
    tex_file = output_dir / "tensor_info_outputs.tex"
    txt_file = output_dir / "tensor_info_outputs.txt"
    explanation: str = f"Test: {input_type} with tensor_requires_grad={tensor_requires_grad} and options={options}\n"
    # Explicit UTF-8: the summaries may contain non-ASCII symbols (e.g. "μ" or
    # block-element sparkline glyphs) that a locale default encoding such as
    # cp1252 cannot encode.
    with open(tex_file, "a", encoding="utf-8") as f_tex, open(
        txt_file, "a", encoding="utf-8"
    ) as f_txt:
        f_tex.write(explanation + summary_str + "\n")
        f_txt.write(explanation + summary_str + "\n")

    # --- Now perform our assertions ---
    # If the input is empty, the summary should mention "empty array".
    if (hasattr(arr, "size") and arr.size == 0) or (
        torch is not None and isinstance(arr, torch.Tensor) and arr.numel() == 0
    ):
        assert (
            "empty" in summary_str.lower()
        ), f"Expected 'empty' in summary for {input_type}"
    else:
        # For non-empty arrays, if the options ask for dtype and shape info, they should appear.
        if options["dtype"]:
            assert (
                "dtype" in summary_str.lower()
            ), f"Expected 'dtype' info in summary for {input_type}"
        if options["shape"]:
            assert (
                "shape" in summary_str.lower()
            ), f"Expected 'shape' info in summary for {input_type}"
        # For torch inputs with device info requested.
        if torch is not None and isinstance(arr, torch.Tensor) and options["device"]:
            assert (
                "device" in summary_str.lower()
            ), f"Expected 'device' info in summary for {input_type}"
        # For arrays with NaNs, if not empty, a warning should be present.
        if input_type.endswith("with_nan"):
            # The case-insensitive check already covers every capitalisation of "nan".
            assert (
                "nan" in summary_str.lower()
            ), f"Expected NaN warning in summary for {input_type}"
        # Check that the equality character appears in the summary (or it's an error)
        assert (
            options["eq_char"] in summary_str
            or tensor_info.SYMBOLS[options["fmt"]]["warning"] in summary_str
        ), f"Expected eq_char '{options['eq_char']}' in summary for {input_type}"
        # If stats are enabled, at least one statistic (e.g. mean) should appear.
        if options["stats"] and not (
            ("empty" in summary_str.lower()) or ("all nan" in summary_str.lower())
        ):
            # The symbol for mean depends on format.
            if options["fmt"] == "unicode":
                assert (
                    "μ" in summary_str
                ), f"Expected unicode mean symbol in summary for {input_type}"
            elif options["fmt"] == "latex":
                assert (
                    r"\mu" in summary_str
                ), f"Expected latex mean symbol in summary for {input_type}"
            elif options["fmt"] == "ascii":
                assert (
                    "mean" in summary_str
                ), f"Expected ascii 'mean' in summary for {input_type}"
        # If sparkline is enabled and the input is non-empty (and not all NaN) then a sparkline should appear.
        if (
            options["sparkline"]
            and summary_str
            and not (
                ("empty" in summary_str.lower()) or ("all nan" in summary_str.lower())
            )
        ):
            # We expect at least one vertical bar or sparkline characters.
            assert "|" in summary_str or any(
                c in summary_str for c in "▁▂▃▄▅▆▇█_-~=#"
            ), f"Expected sparkline in summary for {input_type}"
@pytest.mark.parametrize(
    "bad_input",
    [42, "not an array", {"a": 1}, None],
)
def test_array_summary_failure(bad_input: Any) -> None:
    """
    Check array_summary on inputs that are not arrays.

    The call must return normally, and the result (a list result is joined
    with spaces) must be a non-empty string. The text of the summary itself
    is not inspected.
    """
    result = array_summary(bad_input)
    if isinstance(result, str):
        text: str = result
    else:
        text = " ".join(result)
    assert text
def test_generate_sparkline_basic() -> None:
    """
    Test the sparkline generator with a fixed histogram.

    Covers a linear-scale unicode rendering and a log-scale ascii rendering;
    each must be a string with exactly one character per histogram bin.
    """
    bins: np.ndarray = np.array([1, 3, 5, 2, 0])
    for fmt_name, use_log in (("unicode", False), ("ascii", True)):
        rendered: str = generate_sparkline(bins, format=fmt_name, log_y=use_log)
        assert isinstance(rendered, str)
        assert len(rendered) == len(bins)
if __name__ == "__main__":
    # Run pytest quietly, stop at the first failure, and use pytest's
    # return code as the process exit status.
    pytest_args = ["--maxfail=1", "--disable-warnings", "-q"]
    raise SystemExit(pytest.main(pytest_args))