Coverage for tests / unit / test_statcounter.py: 100%
15 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
1from __future__ import annotations
3import numpy as np
5from muutils.statcounter import StatCounter
8def _compute_err(a: float, b: float | np.floating, /) -> dict[str, int | float]:
9 result: dict[str, int | float] = dict( # type: ignore[invalid-assignment]
10 num_a=float(a),
11 num_b=float(b),
12 diff=float(b - a),
13 # frac_err=float((b - a) / a), # this causes division by zero, whatever
14 )
15 return result
def _compare_np_custom(arr: np.ndarray) -> dict[str, dict[str, float]]:
    """Compare ``StatCounter`` summary statistics of ``arr`` against numpy.

    A ``StatCounter`` is built from ``arr``. For each statistic (mean, std,
    min, 25th percentile, median, 75th percentile, max), the value it
    reports is paired with numpy's value via ``_compute_err``.

    Returns a mapping from statistic name (``mean``, ``std``, ``min``,
    ``q1``, ``median``, ``q3``, ``max``) to that ``_compute_err`` record.
    """
    stats: StatCounter = StatCounter(arr)
    comparison: dict[str, dict[str, float]] = {
        "mean": _compute_err(stats.mean(), np.mean(arr)),
        "std": _compute_err(stats.std(), np.std(arr)),
        "min": _compute_err(stats.min(), np.min(arr)),  # pyright: ignore[reportUnknownArgumentType, reportAny]
        # StatCounter takes the percentile as a fraction; numpy takes 0-100.
        "q1": _compute_err(stats.percentile(0.25), np.percentile(arr, 25)),
        "median": _compute_err(stats.median(), np.median(arr)),
        "q3": _compute_err(stats.percentile(0.75), np.percentile(arr, 75)),
        "max": _compute_err(stats.max(), np.max(arr)),  # pyright: ignore[reportUnknownArgumentType, reportAny]
    }
    return comparison
# Absolute tolerance allowed between a StatCounter statistic and numpy's value.
EPSILON: float = 1.0e-8
def test_statcounter() -> None:
    """Check that StatCounter summary statistics agree with numpy.

    Each array is passed through ``_compare_np_custom``. Every reported
    statistic must match numpy's value to within ``EPSILON`` in absolute
    terms.
    """
    # A fixed seed makes any failure reproducible; the unseeded global
    # ``np.random`` produced different inputs on every run.
    rng: np.random.Generator = np.random.default_rng(0)
    arrs: list[np.ndarray] = [
        np.array([0, 1]),
        np.array([1, 2]),
        rng.integers(0, 10, size=10),
        rng.integers(-5, 15, size=10),
        np.array([-5, -4, -1, 1, 1, 3, 3, 5, 11, 12]),
        rng.integers(-5, 15, size=100),
        rng.integers(0, 100, size=100),
        rng.integers(0, 1000, size=100),
    ]

    for a in arrs:
        r = _compare_np_custom(a)

        # ``diff`` is signed (numpy - custom), so compare its magnitude;
        # otherwise any large negative error would slip through. The ``<``
        # form (rather than ``>=`` on the failure side) also fails on NaN.
        assert all(abs(x["diff"]) < EPSILON for x in r.values()), (
            f"errs for randint array: {a.shape = } {np.min(a) = } {np.max(a) = } data = {r}"
        )