Coverage for muutils/math/bins.py: 100%
32 statements
coverage.py v7.6.1, created at 2025-05-30 22:10 -0600
from __future__ import annotations

from dataclasses import dataclass, replace
from functools import cached_property
from typing import Literal

import numpy as np
from jaxtyping import Float
@dataclass(frozen=True)
class Bins:
    """Immutable description of a set of bins over ``[start, stop]``.

    Bin edges are spaced either linearly (``scale="lin"``) or evenly in
    ``log10`` (``scale="log"``). They are computed lazily on first access of
    :attr:`edges` and then cached on the instance.

    Fields:
        n_bins: number of bins; :attr:`edges` has ``n_bins + 1`` entries.
        start: lower end of the binned range. Must be ``>= 0`` for log scale.
        stop: upper end of the binned range. Must be ``> 0`` for log scale.
        scale: ``"lin"`` for evenly spaced edges, ``"log"`` for edges evenly
            spaced in ``log10``.
        _log_min: for log scale with ``start == 0``, the first nonzero edge.
            A log scale cannot reach zero, so the first bin spans from ``0``
            to ``_log_min``.
        _zero_in_small_start_log: for log scale with ``0 < start < _log_min``:
            if ``True``, an extra ``0`` edge is prepended so the first bin
            spans from ``0`` to ``start``; if ``False``, the edges begin
            exactly at ``start``.
    """

    n_bins: int = 32
    start: float = 0
    stop: float = 1.0
    scale: Literal["lin", "log"] = "log"

    _log_min: float = 1e-3
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Bin edges as a 1D array of length ``n_bins + 1``.

        Raises:
            ValueError: if ``scale`` is not ``"lin"`` or ``"log"``, or, for log
                scale, if ``start`` is negative, ``stop`` is not positive, or
                ``start == 0`` while ``stop <= _log_min``.
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be positive for log scale, got {self.start}"
            )
        # log10 of a non-positive stop is -inf/nan, which would silently yield
        # nan edges instead of a usable binning
        if self.stop <= 0:
            raise ValueError(
                f"stop must be positive for log scale, got {self.stop}"
            )

        if self.start == 0:
            # 0 is not reachable on a log scale: emit a leading 0 edge, then
            # n_bins log-spaced edges from _log_min up to stop
            if self.stop <= self._log_min:
                # otherwise the edges after the leading 0 would run downward
                # (or collapse) from _log_min to stop
                raise ValueError(
                    f"stop must be greater than _log_min ({self._log_min}) "
                    f"for log scale with start == 0, got {self.stop}"
                )
            return np.concatenate(
                [
                    np.array([0]),
                    np.logspace(
                        np.log10(self._log_min), np.log10(self.stop), self.n_bins
                    ),
                ]
            )

        if self.start < self._log_min and self._zero_in_small_start_log:
            # very small positive start: keep start as an edge, but also add a
            # leading 0 edge so values below start fall in the first bin
            return np.concatenate(
                [
                    np.array([0]),
                    np.logspace(
                        np.log10(self.start), np.log10(self.stop), self.n_bins
                    ),
                ]
            )

        return np.logspace(
            np.log10(self.start), np.log10(self.stop), self.n_bins + 1
        )

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Midpoints of consecutive :attr:`edges` (length ``n_bins``).

        NOTE: this is the arithmetic midpoint on both scales, not the
        geometric midpoint, even when ``scale == "log"``.
        """
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a copy of these bins with ``n_bins`` replaced.

        All other fields are carried over unchanged. ``dataclasses.replace``
        builds a fresh instance of the same class, so subclasses keep their
        type. Cached ``edges``/``centers`` are recomputed for the new bin
        count.
        """
        return replace(self, n_bins=n_bins)