Coverage for tests/unit/test_tensor_info.py: 94%
35 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-28 17:24 +0000
1from __future__ import annotations
3from typing import Any
5import numpy as np
6import pytest
8from muutils.tensor_info import array_summary, generate_sparkline
@pytest.mark.parametrize(
    "bad_input",
    [42, "not an array", {"a": 1}, None],
)
def test_array_summary_failure(bad_input: Any) -> None:
    """
    Check that array_summary copes with inputs that are not arrays.

    The summary may come back either as a single string or as an iterable of
    string parts (which are joined with spaces here); in either case the
    rendered text must be non-empty.
    """
    result = array_summary(bad_input)  # type: ignore[call-overload]
    if isinstance(result, str):
        rendered: str = result
    else:
        rendered = " ".join(result)
    assert rendered
def test_generate_sparkline_basic() -> None:
    """
    Exercise generate_sparkline on a small, evenly spread histogram.

    With ``log_y=None`` the scale choice is left to the function, and for this
    input the returned flag must be False (linear). Passing ``log_y=True``
    must be reported back as True. In both cases the sparkline has exactly
    one character per histogram bin.
    """
    histogram: np.ndarray = np.array([1, 3, 5, 2, 0])
    n_bins = len(histogram)

    # automatic scale selection, unicode glyphs
    spark, logy = generate_sparkline(histogram, format="unicode", log_y=None)
    assert isinstance(spark, str)
    assert isinstance(logy, bool)
    assert not logy
    assert len(spark) == n_bins

    # explicitly requested log scale, ascii glyphs
    spark_log, logy_true = generate_sparkline(histogram, format="ascii", log_y=True)
    assert isinstance(spark_log, str)
    assert len(spark_log) == n_bins
    assert logy_true
def test_generate_sparkline_logy() -> None:
    """
    Test that generate_sparkline reports a log y-axis for a skewed histogram.

    One bin (99999) dwarfs all the others, so when the scale is left to the
    function (``log_y=None``) the returned flag must be True. Explicitly
    passing ``log_y=True`` with the ascii format must also report a log
    scale. Each sparkline has exactly one character per histogram bin.
    """
    histogram: np.ndarray = np.array([99999, 3, 5, 2, 0])

    # automatic scale selection: the outlier should trigger log scaling
    spark, logy = generate_sparkline(histogram, format="unicode", log_y=None)
    assert isinstance(spark, str)
    assert isinstance(logy, bool)
    assert logy
    assert len(spark) == len(histogram)

    # explicitly requested log scale, ascii glyphs
    spark_log, logy_true = generate_sparkline(histogram, format="ascii", log_y=True)
    assert isinstance(spark_log, str)
    assert len(spark_log) == len(histogram)
    assert logy_true
if __name__ == "__main__":
    import sys

    # Pass this file explicitly: without a path argument pytest collects tests
    # from the current directory / configured testpaths, not just this module.
    sys.exit(pytest.main([__file__, "--maxfail=1", "--disable-warnings", "-q"]))