Coverage for muutils/statcounter.py: 64%

89 statements  

coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""`StatCounter` class for counting and calculating statistics on numbers 

2 

3cleaner and more efficient than just using a `Counter` or array""" 

4 

5from __future__ import annotations 

6 

7import json 

8import math 

9from collections import Counter 

10from functools import cached_property 

11from itertools import chain 

12from typing import Callable, Optional, Sequence, Union 

13 

14 

15# _GeneralArray = Union[np.ndarray, "torch.Tensor"] 

16NumericSequence = Sequence[Union[float, int, "NumericSequence"]] 

17 

18# pylint: disable=abstract-method 

19 

20# misc 

21# ================================================== 

22 

23 

24def universal_flatten( 

25 arr: Union[NumericSequence, float, int], require_rectangular: bool = True 

26) -> NumericSequence: 

27 """flattens any iterable""" 

28 

29 # mypy complains that the sequence has no attribute "flatten" 

30 if hasattr(arr, "flatten") and callable(arr.flatten): # type: ignore 

31 return arr.flatten() # type: ignore 

32 elif isinstance(arr, Sequence): 

33 elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr] 

34 if require_rectangular and (all(elements_iterable) != any(elements_iterable)): 

35 raise ValueError("arr contains mixed iterable and non-iterable elements") 

36 if any(elements_iterable): 

37 return list(chain.from_iterable(universal_flatten(x) for x in arr)) # type: ignore[misc] 

38 else: 

39 return arr 

40 else: 

41 return [arr] 

42 
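
# Example (illustrative sketch, assuming the implementation above):
#
#     >>> universal_flatten([[1, 2], [3, 4]])
#     [1, 2, 3, 4]
#     >>> universal_flatten(5)
#     [5]
#     >>> universal_flatten([[1, 2], 3])
#     # raises ValueError: arr contains mixed iterable and non-iterable elements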

# StatCounter
# ==================================================


class StatCounter(Counter):
    """`Counter`, but with some stat calculation methods which assume the keys are numerical

    works best when the keys are `int`s
    """

    def validate(self) -> bool:
        """check that every key is numeric (`bool`, `int`, `float`, or `None`)"""
        return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys())

    def min(self):
        "minimum key with a nonzero count"
        return min(x for x, v in self.items() if v > 0)

    def max(self):
        "maximum key with a nonzero count"
        return max(x for x, v in self.items() if v > 0)

    def total(self):
        """Sum of the counts"""
        return sum(self.values())
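
    # Example (illustrative sketch, assuming the implementation above):
    #
    #     >>> c = StatCounter([1, 1, 2, 3, 3, 3])   # counts -> {1: 2, 2: 1, 3: 3}
    #     >>> c.validate()
    #     True
    #     >>> c.min(), c.max(), c.total()
    #     (1, 3, 6)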

    @cached_property
    def keys_sorted(self) -> list:
        """return the keys, sorted ascending"""
        return sorted(list(self.keys()))

    def percentile(self, p: float):
        """return the value at the given percentile `p`, where `0 <= p <= 1`

        interpolates linearly between adjacent keys when the target falls between them.
        this could be log time if we did binary search, but that would be a lot of added complexity
        """

        if p < 0 or p > 1:
            raise ValueError(f"percentile must be between 0 and 1: {p}")

        # iterate from whichever end is closer to the target percentile, for speed
        sorted_keys: list[float] = [float(x) for x in self.keys_sorted]
        sort: int = 1
        if p > 0.51:
            sort = -1
            p = 1 - p

        sorted_keys = sorted_keys[::sort]
        real_target: float = p * (self.total() - 1)

        n_target_f: int = math.floor(real_target)
        n_target_c: int = math.ceil(real_target)

        n_sofar: float = -1

        for i, k in enumerate(sorted_keys):
            n_sofar += self[k]

            if n_sofar > n_target_f:
                return k

            elif n_sofar == n_target_f:
                if n_sofar == n_target_c:
                    return k
                else:
                    # target falls between this key and the next: interpolate linearly
                    return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[
                        i + 1
                    ] * (real_target - n_sofar)
            else:
                continue

        raise ValueError(f"percentile {p} not found (this should be unreachable)")
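
    # Example (illustrative sketch, assuming the implementation above):
    #
    #     >>> c = StatCounter([1, 1, 2, 3, 3, 3])
    #     >>> c.percentile(0.5)   # falls between keys 2 and 3, so it interpolates
    #     2.5
    #     >>> c.percentile(0.25)
    #     1.25
    #     >>> c.percentile(0.75)
    #     3.0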

    def median(self) -> float:
        """return the median (50th percentile) of the values"""
        return self.percentile(0.5)

    def mean(self) -> float:
        """return the mean of the values"""
        return float(sum(k * c for k, c in self.items()) / self.total())

    def mode(self) -> float:
        """return the most common value"""
        return self.most_common()[0][0]

    def std(self) -> float:
        """return the (population) standard deviation of the values"""
        mean: float = self.mean()
        deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items())

        return (deviations / self.total()) ** 0.5
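
    # Example (illustrative sketch; rounding used to avoid pinning float reprs):
    #
    #     >>> c = StatCounter([1, 1, 2, 3, 3, 3])
    #     >>> c.mode()
    #     3
    #     >>> round(c.mean(), 4)
    #     2.1667
    #     >>> round(c.std(), 4)   # population std: divides by N, not N - 1
    #     0.8975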

    def summary(
        self,
        typecast: Callable = lambda x: x,
        *,
        extra_percentiles: Optional[list[float]] = None,
    ) -> dict[str, Union[float, int]]:
        """return a summary of the stats, without the raw data. human readable and small"""
        # stats that work for any counter (guard `mode`, which would raise on an empty counter)
        output: dict = dict(
            total_items=self.total(),
            n_keys=len(self.keys()),
            mode=self.mode() if self.total() > 0 else None,
        )

        if self.total() > 0:
            if self.validate():
                # if it's a numeric counter, we can do some stats
                output = {
                    **output,
                    **dict(
                        mean=float(self.mean()),
                        std=float(self.std()),
                        min=typecast(self.min()),
                        q1=typecast(self.percentile(0.25)),
                        median=typecast(self.median()),
                        q3=typecast(self.percentile(0.75)),
                        max=typecast(self.max()),
                    ),
                }

                if extra_percentiles is not None:
                    for p in extra_percentiles:
                        output[f"percentile_{p}"] = typecast(self.percentile(p))
            else:
                # if it's not, we can only do the simpler things
                # mode and total are already done in the initial declaration of `output`
                pass

        return output
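
    # Example (illustrative sketch; `typecast=int` squashes interpolated quartiles to ints):
    #
    #     >>> c = StatCounter([1, 1, 2, 3, 3, 3])
    #     >>> s = c.summary(typecast=int, extra_percentiles=[0.9])
    #     >>> s["median"], s["q1"], s["q3"], s["percentile_0.9"]
    #     (2, 1, 3, 3)
    #     >>> sorted(s.keys())
    #     ['max', 'mean', 'median', 'min', 'mode', 'n_keys', 'percentile_0.9', 'q1', 'q3', 'std', 'total_items']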

    def serialize(
        self,
        typecast: Callable = lambda x: x,
        *,
        extra_percentiles: Optional[list[float]] = None,
    ) -> dict:
        """return a json-serializable version of the counter

        includes both the output of `summary` and the raw data:

        ```json
        {
            "StatCounter": { <keys, values from raw data> },
            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
        }
        ```
        """

        return {
            "StatCounter": {
                typecast(k): v
                for k, v in sorted(dict(self).items(), key=lambda x: x[0])
            },
            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
        }
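
    # Example (illustrative sketch, assuming the implementation above):
    #
    #     >>> c = StatCounter([1, 1, 2, 3, 3, 3])
    #     >>> data = c.serialize(typecast=int)
    #     >>> data["StatCounter"]
    #     {1: 2, 2: 1, 3: 3}
    #     >>> import json
    #     >>> _ = json.dumps(data)   # everything in the output is json-serializable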

    def __str__(self) -> str:
        "summary as json with 2 space indent, good for printing"
        return json.dumps(self.summary(), indent=2)

    def __repr__(self) -> str:
        "full serialization (summary plus raw data) as json with 2 space indent"
        return json.dumps(self.serialize(), indent=2)
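
    # Example (illustrative sketch): `print(counter)` emits just the json summary,
    # while the repl echo (`__repr__`) additionally includes the raw counts:
    #
    #     >>> print(StatCounter([2, 2]))   # summary only
    #     >>> StatCounter([2, 2])          # summary plus raw data, still valid json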

    @classmethod
    def load(cls, data: dict) -> "StatCounter":
        "load from the output of `StatCounter.serialize`"
        if "StatCounter" in data:
            loadme = data["StatCounter"]
        else:
            loadme = data

        return cls({float(k): v for k, v in loadme.items()})
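
    # Example (illustrative sketch; note that keys come back as floats):
    #
    #     >>> c = StatCounter([1, 1, 2, 3, 3, 3])
    #     >>> restored = StatCounter.load(c.serialize())
    #     >>> restored == StatCounter({1.0: 2, 2.0: 1, 3.0: 3})
    #     True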

    @classmethod
    def from_list_arrays(
        cls,
        arr,
        map_func: Callable = float,
    ) -> "StatCounter":
        """calls `map_func` on each element of `universal_flatten(arr)` and counts the results"""
        return cls([map_func(x) for x in universal_flatten(arr)])
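
# Example (illustrative sketch, assuming the implementation above):
#
#     >>> c = StatCounter.from_list_arrays([[1, 2], [2, 3]])
#     >>> c == StatCounter({1.0: 1, 2.0: 2, 3.0: 1})
#     True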