Coverage for muutils / math / bins.py: 100%
32 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
1from __future__ import annotations
3from dataclasses import dataclass
4from functools import cached_property
5from typing import Literal
7import numpy as np
8from jaxtyping import Float
@dataclass(frozen=True)
class Bins:
    """Immutable description of ``n_bins`` bins spanning ``start`` to ``stop``.

    The bin ``edges`` (and the derived ``centers``) are computed on first
    access and cached on the instance via ``functools.cached_property``.

    Attributes:
        n_bins: number of bins; ``edges`` has ``n_bins + 1`` entries.
        start: lower end of the range. Must be non-negative for ``"log"``.
        stop: upper end of the range. Must be positive for ``"log"``.
        scale: ``"lin"`` for evenly spaced edges, ``"log"`` for log-spaced.
        _log_min: first nonzero edge used for a log scale when ``start == 0``.
        _zero_in_small_start_log: for a log scale with
            ``0 < start < _log_min``, prepend an extra edge at 0 (the
            log-spaced part then starts at ``start`` and has ``n_bins`` edges).
    """

    n_bins: int = 32
    start: float = 0
    stop: float = 1.0
    scale: Literal["lin", "log"] = "log"

    _log_min: float = 1e-3
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Bin edges, an array of ``n_bins + 1`` values.

        - ``"lin"``: ``np.linspace(start, stop, n_bins + 1)``.
        - ``"log"``: log-spaced edges ending at ``stop``. Because ``log10(0)``
          is undefined, when ``start == 0`` the first edge is 0 and the other
          ``n_bins`` edges are log-spaced from ``_log_min`` to ``stop``. When
          ``0 < start < _log_min`` and ``_zero_in_small_start_log`` is set,
          the same zero edge is prepended, but the log-spaced part starts at
          ``start``. Otherwise all ``n_bins + 1`` edges are log-spaced from
          ``start`` to ``stop``.

        Raises:
            ValueError: if ``scale`` is not ``"lin"`` or ``"log"``, or, for a
                log scale, if ``start`` is negative or ``stop`` is not
                positive (either would feed a non-positive value to
                ``np.log10`` and produce ``nan``/``-inf`` edges).
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be non-negative for log scale, got {self.start}"
            )
        if self.stop <= 0:
            raise ValueError(
                f"stop must be positive for log scale, got {self.stop}"
            )

        if self.start == 0 or (
            self.start < self._log_min and self._zero_in_small_start_log
        ):
            # Pin the lowest edge at 0 and log-space the remaining n_bins edges
            # from a strictly positive lower bound: `_log_min` when start is
            # exactly 0 (log10(0) is undefined), otherwise `start` itself.
            log_start = self._log_min if self.start == 0 else self.start
            return np.concatenate(
                [  # pyright: ignore[reportUnknownArgumentType]
                    np.array([0.0]),
                    np.logspace(
                        np.log10(log_start),  # pyright: ignore[reportAny]
                        np.log10(self.stop),  # pyright: ignore[reportAny]
                        self.n_bins,
                    ),
                ]
            )

        return np.logspace(  # pyright: ignore[reportUnknownVariableType]
            np.log10(self.start),  # pyright: ignore[reportAny]
            np.log10(self.stop),  # pyright: ignore[reportAny]
            self.n_bins + 1,
        )

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Midpoints of consecutive edges (one per bin).

        These are arithmetic means ``(lo + hi) / 2`` for every scale, so for
        log bins they are not the geometric (log-space) centers.
        """
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a copy of these bins with only ``n_bins`` replaced.

        Uses ``dataclasses.replace`` so every field (including any added
        later) is carried over. The copy has the same class as ``self``.
        Cached ``edges``/``centers`` are not copied; they are recomputed
        lazily for the new bin count.
        """
        import dataclasses

        return dataclasses.replace(self, n_bins=n_bins)