Coverage for muutils / math / bins.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from functools import cached_property 

5from typing import Literal 

6 

7import numpy as np 

8from jaxtyping import Float 

9 

10 

@dataclass(frozen=True)
class Bins:
    """Immutable specification of ``n_bins`` bins spanning ``[start, stop]``.

    Bin edges are spaced either linearly (``scale="lin"``) or logarithmically
    (``scale="log"``). Because ``log10(0)`` is undefined, a log scale with
    ``start == 0`` uses a leading edge at exactly 0 followed by log-spaced
    edges from ``_log_min`` to ``stop``. The same leading zero edge is added
    when ``0 < start < _log_min`` and ``_zero_in_small_start_log`` is True.

    ``edges`` and ``centers`` are computed lazily on first access and cached
    on the instance.
    """

    n_bins: int = 32
    start: float = 0
    stop: float = 1.0
    scale: Literal["lin", "log"] = "log"

    # first log-spaced edge when ``start == 0``; also the threshold below which
    # a positive ``start`` is "small" for ``_zero_in_small_start_log``
    _log_min: float = 1e-3
    # if True, a log scale with ``0 < start < _log_min`` also gets an edge at 0
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Return the ``n_bins + 1`` bin edges.

        Raises:
            ValueError: if ``scale`` is not ``"lin"`` or ``"log"``; or, for a
                log scale, if ``start`` is negative or ``stop`` does not exceed
                the first log-spaced edge (which would produce degenerate or
                non-increasing edges).
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be positive for log scale, got {self.start}"
            )

        # log10(0) is undefined, so a zero start begins the log range at _log_min
        log_lo: float = self._log_min if self.start == 0 else self.start
        prepend_zero: bool = self.start == 0 or (
            self.start < self._log_min and self._zero_in_small_start_log
        )
        # a stop at/below the log lower bound yields constant or decreasing
        # edges (or nan/-inf from log10 when stop <= 0)
        if self.stop <= log_lo:
            raise ValueError(
                f"stop must be greater than {log_lo} for log scale, got {self.stop}"
            )

        if prepend_zero:
            # the leading 0 edge takes one slot, so only n_bins log-spaced edges
            return np.concatenate(
                [  # pyright: ignore[reportUnknownArgumentType]
                    np.array([0.0]),
                    np.logspace(
                        np.log10(log_lo),  # pyright: ignore[reportAny]
                        np.log10(self.stop),  # pyright: ignore[reportAny]
                        self.n_bins,
                    ),
                ]
            )
        return np.logspace(  # pyright: ignore[reportUnknownVariableType]
            np.log10(log_lo),  # pyright: ignore[reportAny]
            np.log10(self.stop),  # pyright: ignore[reportAny]
            self.n_bins + 1,
        )

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Return the arithmetic midpoint of each pair of adjacent edges."""
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a copy with ``n_bins`` replaced and every other field kept.

        The copy is freshly constructed, so cached ``edges``/``centers`` are
        not carried over and are recomputed for the new bin count.
        """
        import dataclasses

        return dataclasses.replace(self, n_bins=n_bins)