Coverage for muutils/math/bins.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-30 22:10 -0600

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from functools import cached_property 

5from typing import Literal 

6 

7import numpy as np 

8from jaxtyping import Float 

9 

10 

@dataclass(frozen=True)
class Bins:
    """Immutable description of a 1-D binning on a linear or log scale.

    Fields:
        n_bins: number of bins; ``edges`` has ``n_bins + 1`` entries.
        start: lower end of the binned range.
        stop: upper end of the binned range.
        scale: ``"lin"`` for edges evenly spaced in value, ``"log"`` for edges
            evenly spaced in ``log10``.
        _log_min: on a log scale with ``start == 0``, the first nonzero edge
            (``log10(0)`` is undefined, so ``[0, _log_min]`` becomes the
            lowest bin).
        _zero_in_small_start_log: on a log scale with
            ``0 < start < _log_min``, whether to add an extra edge at 0 below
            ``start``.
    """

    n_bins: int = 32
    start: float = 0
    stop: float = 1.0
    scale: Literal["lin", "log"] = "log"

    _log_min: float = 1e-3
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Bin edges, computed on first access and cached on the instance.

        Returns:
            Array of ``n_bins + 1`` edges. On a log scale, when ``start == 0``
            (or ``start < _log_min`` with ``_zero_in_small_start_log``), the
            first edge is 0 and the remaining ``n_bins`` edges are
            log-spaced up to ``stop``.

        Raises:
            ValueError: if ``scale`` is not ``"lin"`` or ``"log"``; or, on a
                log scale, if ``start`` is negative or ``stop`` is not
                positive.
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be non-negative for log scale, got {self.start}"
            )
        if self.stop <= 0:
            # np.log10 of a non-positive value is -inf/nan, which would
            # silently produce meaningless edges.
            raise ValueError(f"stop must be positive for log scale, got {self.stop}")

        # NOTE(review): nothing checks that stop lies above the first log edge;
        # e.g. start=0 with stop < _log_min yields non-increasing edges.
        prepend_zero = self.start == 0 or (
            self.start < self._log_min and self._zero_in_small_start_log
        )
        if prepend_zero:
            # log10(0) is undefined, so the lowest bin is [0, first_log_edge]
            # and the other n_bins edges are log-spaced from there to stop.
            first_log_edge = self._log_min if self.start == 0 else self.start
            return np.concatenate(
                [
                    np.array([0]),
                    np.logspace(
                        np.log10(first_log_edge), np.log10(self.stop), self.n_bins
                    ),
                ]
            )

        return np.logspace(
            np.log10(self.start), np.log10(self.stop), self.n_bins + 1
        )

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Midpoints of consecutive edges (``n_bins`` values), cached.

        These are arithmetic midpoints even on a log scale (not geometric
        means).
        """
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a copy of these bins with only ``n_bins`` changed.

        ``dataclasses.replace`` carries over every other field (including any
        added to the class later) and builds a fresh instance, so cached
        ``edges``/``centers`` are recomputed rather than copied.
        """
        from dataclasses import replace

        return replace(self, n_bins=n_bins)