Coverage for tests/unit/test_tensor_info.py: 94%

35 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-28 17:24 +0000

1from __future__ import annotations 

2 

3from typing import Any 

4 

5import numpy as np 

6import pytest 

7 

8from muutils.tensor_info import array_summary, generate_sparkline 

9 

10 

@pytest.mark.parametrize(
    "bad_input",
    [
        42,
        "not an array",
        {"a": 1},
        None,
    ],
)
def test_array_summary_failure(bad_input: Any) -> None:
    """
    Test that array_summary tolerates non-array inputs.

    For each non-array ``bad_input`` the call must not raise, and the summary
    it returns (either a string, or a sequence of strings that is joined with
    spaces here) must be non-empty. The exact wording of the summary is not
    checked by this test.
    """
    summary = array_summary(bad_input)  # type: ignore[call-overload]
    # array_summary may hand back a single string or a list of parts;
    # normalise to one string so the emptiness check covers both forms
    summary_str: str = summary if isinstance(summary, str) else " ".join(summary)
    assert summary_str, f"array_summary returned an empty summary for {bad_input!r}"

28 

29 

def test_generate_sparkline_basic() -> None:
    """
    Check sparkline output for a small histogram with no extreme outliers.

    Letting the generator choose the scale (``log_y=None``) is expected to
    report that no log scale was used; forcing ``log_y=True`` must be echoed
    back in the returned flag. Either way, the sparkline has exactly one
    character per histogram bin.
    """
    counts: np.ndarray = np.array([1, 3, 5, 2, 0])
    n_bins: int = len(counts)

    # scale left to the generator, unicode glyphs
    line_auto, used_log_auto = generate_sparkline(counts, format="unicode", log_y=None)
    assert isinstance(line_auto, str)
    assert isinstance(used_log_auto, bool)
    assert not used_log_auto
    assert len(line_auto) == n_bins

    # log scale forced on, ascii glyphs
    line_forced, used_log_forced = generate_sparkline(counts, format="ascii", log_y=True)
    assert isinstance(line_forced, str)
    assert len(line_forced) == n_bins
    assert used_log_forced

45 

46 

def test_generate_sparkline_logy() -> None:
    """
    Test that the sparkline generator picks a log y-axis for a skewed histogram.

    One bin (99999) dwarfs the others, so with ``log_y=None`` the returned
    log-scale flag is expected to be True. Explicitly requesting
    ``log_y=True`` must also report True. In both cases the sparkline has
    one character per histogram bin.
    """
    histogram: np.ndarray = np.array([99999, 3, 5, 2, 0])

    # log_y=None: the generator decides the scale itself
    spark, logy = generate_sparkline(histogram, format="unicode", log_y=None)
    assert isinstance(spark, str)
    assert isinstance(logy, bool)
    assert logy
    assert len(spark) == len(histogram)

    # log_y=True: log scale explicitly requested
    spark_log, logy_forced = generate_sparkline(histogram, format="ascii", log_y=True)
    assert isinstance(spark_log, str)
    assert len(spark_log) == len(histogram)
    assert logy_forced

62 

63 

if __name__ == "__main__":
    import sys

    # Pass this file explicitly: with no path argument, pytest collects from
    # the configured testpaths or the current working directory, which would
    # run the whole suite instead of just this module.
    sys.exit(pytest.main([__file__, "--maxfail=1", "--disable-warnings", "-q"]))