Coverage for tests / unit / test_statcounter.py: 100%

15 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:25 -0700

1from __future__ import annotations 

2 

3import numpy as np 

4 

5from muutils.statcounter import StatCounter 

6 

7 

8def _compute_err(a: float, b: float | np.floating, /) -> dict[str, int | float]: 

9 result: dict[str, int | float] = dict( # type: ignore[invalid-assignment] 

10 num_a=float(a), 

11 num_b=float(b), 

12 diff=float(b - a), 

13 # frac_err=float((b - a) / a), # this causes division by zero, whatever 

14 ) 

15 return result 

16 

17 

def _compare_np_custom(arr: np.ndarray) -> dict[str, dict[str, float]]:
    """Compare ``StatCounter`` summary statistics for ``arr`` against numpy.

    A ``StatCounter`` is built from ``arr``. Each statistic below pairs the
    counter's value with numpy's, and each pair is reduced with
    ``_compute_err``.

    Returns a dict keyed by statistic name, in this order: mean, std, min,
    q1, median, q3, max. Each value is the ``_compute_err`` dict for that
    statistic.
    """
    stats: StatCounter = StatCounter(arr)
    # The counter is queried with fractions (0.25 / 0.75) while numpy is given
    # 0-100 percentages for the same quartiles.
    table = (
        ("mean", stats.mean(), np.mean(arr)),
        ("std", stats.std(), np.std(arr)),
        ("min", stats.min(), np.min(arr)),
        ("q1", stats.percentile(0.25), np.percentile(arr, 25)),
        ("median", stats.median(), np.median(arr)),
        ("q3", stats.percentile(0.75), np.percentile(arr, 75)),
        ("max", stats.max(), np.max(arr)),
    )
    return {name: _compute_err(ours, ref) for name, ours, ref in table}  # pyright: ignore[reportUnknownArgumentType, reportAny]

29 

30 

# Absolute tolerance used when comparing StatCounter results against numpy.
EPSILON: float = 1.0e-8

32 

33 

def test_statcounter() -> None:
    """Check ``StatCounter`` summary statistics against numpy on several arrays.

    Uses a few hand-picked arrays plus integer arrays drawn from a seeded
    generator, across several value ranges and sizes. Every statistic returned
    by ``_compare_np_custom`` must match numpy's value to within ``EPSILON`` in
    absolute terms.
    """
    # Fixed seed: a failure on one of the random arrays can be reproduced exactly.
    rng: np.random.Generator = np.random.default_rng(42)
    arrs: list[np.ndarray] = [
        np.array([0, 1]),
        np.array([1, 2]),
        rng.integers(0, 10, size=10),
        rng.integers(-5, 15, size=10),
        np.array([-5, -4, -1, 1, 1, 3, 3, 5, 11, 12]),
        rng.integers(-5, 15, size=100),
        rng.integers(0, 100, size=100),
        rng.integers(0, 1000, size=100),
    ]

    for a in arrs:
        r = _compare_np_custom(a)

        # `diff` is signed (numpy minus StatCounter), so compare its magnitude.
        # Otherwise any case where StatCounter overshoots numpy would pass.
        # Written as `not (... < EPSILON)` so a NaN difference is also flagged.
        failed = {
            name: err for name, err in r.items() if not abs(err["diff"]) < EPSILON
        }
        assert not failed, (
            f"errs for array: {a.shape = } {np.min(a) = } {np.max(a) = } "
            f"{failed = } data = {r}"
        )