Coverage for muutils/statcounter.py: 64%
89 statements
coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""`StatCounter` class for counting and calculating statistics on numbers
3cleaner and more efficient than just using a `Counter` or array"""
5from __future__ import annotations
7import json
8import math
9from collections import Counter
10from functools import cached_property
11from itertools import chain
12from typing import Callable, Optional, Sequence, Union
15# _GeneralArray = Union[np.ndarray, "torch.Tensor"]
16NumericSequence = Sequence[Union[float, int, "NumericSequence"]]
18# pylint: disable=abstract-method
20# misc
21# ==================================================
24def universal_flatten(
25 arr: Union[NumericSequence, float, int], require_rectangular: bool = True
26) -> NumericSequence:
27 """flattens any iterable"""
29 # mypy complains that the sequence has no attribute "flatten"
30 if hasattr(arr, "flatten") and callable(arr.flatten): # type: ignore
31 return arr.flatten() # type: ignore
32 elif isinstance(arr, Sequence):
33 elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr]
34 if require_rectangular and (all(elements_iterable) != any(elements_iterable)):
35 raise ValueError("arr contains mixed iterable and non-iterable elements")
36 if any(elements_iterable):
37 return list(chain.from_iterable(universal_flatten(x) for x in arr)) # type: ignore[misc]
38 else:
39 return arr
40 else:
41 return [arr]
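
# Illustrative sketch of `universal_flatten` (not part of the module): nested
# sequences are flattened recursively; with the default `require_rectangular=True`,
# mixing sequence and scalar elements at the same level raises a `ValueError`.
#   >>> universal_flatten([[1, 2], [3, 4]])
#   [1, 2, 3, 4]
#   >>> universal_flatten([1, [2, 3]], require_rectangular=False)
#   [1, 2, 3]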

# StatCounter
# ==================================================


class StatCounter(Counter):
    """`Counter`, but with some stat calculation methods which assume the keys are numerical

    works best when the keys are `int`s
    """

    def validate(self) -> bool:
        """validate that every key is numeric (`bool`, `int`, `float`, or `None`)"""
        return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys())

    def min(self):
        "minimum value"
        return min(x for x, v in self.items() if v > 0)

    def max(self):
        "maximum value"
        return max(x for x, v in self.items() if v > 0)

    def total(self):
        """Sum of the counts"""
        return sum(self.values())

    @cached_property
    def keys_sorted(self) -> list:
        """return the keys, sorted"""
        return sorted(list(self.keys()))

    def percentile(self, p: float):
        """return the value at the given percentile

        this could be log time if we did binary search, but that would be a lot of added complexity
        """
        if p < 0 or p > 1:
            raise ValueError(f"percentile must be between 0 and 1: {p}")

        # flip for speed: walk from whichever end of the sorted keys is closer
        sorted_keys: list[float] = [float(x) for x in self.keys_sorted]
        sort: int = 1
        if p > 0.51:
            sort = -1
            p = 1 - p

        sorted_keys = sorted_keys[::sort]
        real_target: float = p * (self.total() - 1)

        n_target_f: int = math.floor(real_target)
        n_target_c: int = math.ceil(real_target)

        n_sofar: float = -1

        # print(f'{p = } {real_target = } {n_target_f = } {n_target_c = }')

        # accumulate counts until the target index is reached; if the target
        # falls between two keys, interpolate linearly between them
        for i, k in enumerate(sorted_keys):
            n_sofar += self[k]

            # print(f'{k = } {n_sofar = }')

            if n_sofar > n_target_f:
                return k
            elif n_sofar == n_target_f:
                if n_sofar == n_target_c:
                    return k
                else:
                    # print(
                    #     sorted_keys[i], (n_sofar + 1 - real_target),
                    #     sorted_keys[i + 1], (real_target - n_sofar),
                    # )
                    return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[
                        i + 1
                    ] * (real_target - n_sofar)
            else:
                continue

        raise ValueError(f"percentile {p} not found???")
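
    # Behaviour sketch (illustrative, not part of the module): `percentile`
    # interpolates linearly when the target index falls between two keys,
    # e.g. for the multiset {1, 1, 2, 3}:
    #   >>> StatCounter([1, 1, 2, 3]).percentile(0.25)
    #   1.0
    #   >>> StatCounter([1, 1, 2, 3]).percentile(0.5)
    #   1.5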

    def median(self) -> float:
        """return the median (50th percentile) of the values"""
        return self.percentile(0.5)

    def mean(self) -> float:
        """return the mean of the values"""
        return float(sum(k * c for k, c in self.items()) / self.total())

    def mode(self) -> float:
        """return the most common value"""
        return self.most_common()[0][0]

    def std(self) -> float:
        """return the standard deviation of the values"""
        mean: float = self.mean()
        deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items())

        return (deviations / self.total()) ** 0.5

    def summary(
        self,
        typecast: Callable = lambda x: x,
        *,
        extra_percentiles: Optional[list[float]] = None,
    ) -> dict[str, Union[float, int]]:
        """return a summary of the stats, without the raw data. human readable and small"""
        # common stats that always work
        output: dict = dict(
            total_items=self.total(),
            n_keys=len(self.keys()),
            mode=self.mode(),
        )

        if self.total() > 0:
            if self.validate():
                # if it's a numeric counter, we can do some stats
                output = {
                    **output,
                    **dict(
                        mean=float(self.mean()),
                        std=float(self.std()),
                        min=typecast(self.min()),
                        q1=typecast(self.percentile(0.25)),
                        median=typecast(self.median()),
                        q3=typecast(self.percentile(0.75)),
                        max=typecast(self.max()),
                    ),
                }

                if extra_percentiles is not None:
                    for p in extra_percentiles:
                        output[f"percentile_{p}"] = typecast(self.percentile(p))
            else:
                # if it's not numeric, we can only do the simpler things;
                # `total_items`, `n_keys`, and `mode` are already set in the
                # initial declaration of `output`
                pass

        return output
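
    # Note (not part of the module): for a numeric counter, `summary()` returns
    # `total_items`, `n_keys`, `mode`, `mean`, `std`, `min`, `q1`, `median`,
    # `q3`, `max`, plus one `percentile_<p>` entry per value in
    # `extra_percentiles`; for non-numeric keys only the first three are kept.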

    def serialize(
        self,
        typecast: Callable = lambda x: x,
        *,
        extra_percentiles: Optional[list[float]] = None,
    ) -> dict:
        """return a json-serializable version of the counter

        includes both the output of `summary` and the raw data:

        ```json
        {
            "StatCounter": { <keys, values from raw data> },
            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
        }
        ```
        """
        return {
            "StatCounter": {
                typecast(k): v
                for k, v in sorted(dict(self).items(), key=lambda x: x[0])
            },
            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
        }

    def __str__(self) -> str:
        "summary as json with 2 space indent, good for printing"
        return json.dumps(self.summary(), indent=2)

    def __repr__(self) -> str:
        "full serialization (raw data and summary) as json with 2 space indent"
        return json.dumps(self.serialize(), indent=2)

    @classmethod
    def load(cls, data: dict) -> "StatCounter":
        "load from the output of `StatCounter.serialize`"
        if "StatCounter" in data:
            loadme = data["StatCounter"]
        else:
            loadme = data

        return cls({float(k): v for k, v in loadme.items()})
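
    # Round-trip sketch (illustrative, not part of the module): `load` accepts
    # the output of `serialize` and casts the keys back to `float`.
    #   >>> sc = StatCounter([1, 1, 2, 3])
    #   >>> StatCounter.load(sc.serialize()) == sc
    #   True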

    @classmethod
    def from_list_arrays(
        cls,
        arr,
        map_func: Callable = float,
    ) -> "StatCounter":
        """calls `map_func` on each element of `universal_flatten(arr)`"""
        return cls([map_func(x) for x in universal_flatten(arr)])
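
# End-to-end usage sketch (illustrative, not part of the module):
#   >>> sc = StatCounter.from_list_arrays([[1, 2], [3, 4]])
#   >>> sc.mean(), sc.min(), sc.max()
#   (2.5, 1.0, 4.0)
#   >>> print(sc)  # prints the summary as indented JSON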