Coverage for muutils/math/matrix_powers.py: 96%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-30 22:10 -0600

1from __future__ import annotations 

2 

3from typing import List, Sequence, TYPE_CHECKING 

4 

5import numpy as np 

6from jaxtyping import Float, Int 

7 

8if TYPE_CHECKING: 

9 pass 

10 

11 

def matrix_powers(
    A: Float[np.ndarray, "n n"],
    powers: Sequence[int],
) -> Float[np.ndarray, "n_powers n n"]:
    """Compute multiple non-negative integer powers of a square matrix.

    Uses binary exponentiation: the repeated squarings `A^1, A^2, A^4, ...`
    are computed once (about `log2(max(powers))` matrix multiplications) and
    shared by every requested power. Each requested power is then assembled
    from those squarings with one multiplication per set bit of the power.

    # Parameters:
     - `A : Float[np.ndarray, "n n"]`
        Square matrix to exponentiate
     - `powers : Sequence[int]`
        Powers to compute (non-negative integers). Duplicates are collapsed
        and the order of the input is not preserved.

    # Returns:
     - `Float[np.ndarray, "n_powers_unique n n"]`
        Array of dtype `A.dtype` whose `i`-th slice along the first axis is
        `A` raised to the `i`-th smallest distinct value in `powers`

    # Raises:
     - `AssertionError` : if `A` is not a 2-D square matrix
       (the check is an `assert`, so it is skipped under `python -O`)
     - `ValueError` : if `powers` is empty or contains a negative value
    """
    assert (
        A.ndim == 2 and A.shape[0] == A.shape[1]
    ), f"Matrix must be square, but got {A.shape = }"
    dim_n: int = A.shape[0]

    # sorted + deduplicated, so index 0 is the minimum and index -1 the maximum
    powers_np: Int[np.ndarray, "n_powers_unique"] = np.array(
        sorted(set(powers)), dtype=int
    )
    n_powers_unique: int = len(powers_np)

    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")

    # a negative power would skip the decomposition loop below and silently
    # come back as the identity matrix, so reject it explicitly
    if powers_np[0] < 0:
        raise ValueError(f"Powers must be non-negative, but got {powers = }")

    # every slice is assigned in the loop below, so no fill value is needed;
    # this also avoids casting NaN into integer dtypes
    output: Float[np.ndarray, "n_powers_unique n n"] = np.empty(
        (n_powers_unique, dim_n, dim_n),
        dtype=A.dtype,
    )

    max_power: int = int(powers_np[-1])

    # repeated squarings, keyed by exponent: {1: A, 2: A^2, 4: A^4, ...}
    # only exponents <= max_power can appear in a binary decomposition
    squarings: dict[int, Float[np.ndarray, "n n"]] = {1: A}
    p: int = 1
    while p * 2 <= max_power:
        squarings[p * 2] = squarings[p] @ squarings[p]
        p *= 2

    identity: Float[np.ndarray, "n n"] = np.eye(dim_n, dtype=A.dtype)

    for p_idx, power in enumerate(powers_np.tolist()):
        # walk the bits of `power` from least to most significant,
        # multiplying in the squaring that corresponds to each set bit
        result: Float[np.ndarray, "n n"] = identity
        remaining: int = power
        bit_value: int = 1

        while remaining > 0:
            if remaining % 2 == 1:
                result = result @ squarings[bit_value]
            remaining //= 2
            bit_value *= 2

        # slice assignment copies the values into `output`
        output[p_idx] = result

    return output

81 

82 

# TYPING: jaxtyping hints not working here, separate file for torch implementation?
def matrix_powers_torch(
    A,  # : Float["torch.Tensor", "n n"],
    powers: Sequence[int],
):  # Float["torch.Tensor", "n_powers n n"]:
    """Compute multiple non-negative integer powers of a square matrix.

    Uses binary exponentiation: the repeated squarings `A^1, A^2, A^4, ...`
    are computed once (about `log2(max(powers))` matrix multiplications) and
    shared by every requested power. Each requested power is then assembled
    from those squarings with one multiplication per set bit of the power.

    # Parameters:
     - `A : Float[torch.Tensor, "n n"]`
        Square matrix to exponentiate
     - `powers : Sequence[int]`
        Powers to compute (non-negative integers). Duplicates are collapsed
        and the order of the input is not preserved.

    # Returns:
     - `Float[torch.Tensor, "n_powers_unique n n"]`
        Tensor with the dtype and device of `A` whose `i`-th slice along the
        first dimension is `A` raised to the `i`-th smallest distinct value
        in `powers`

    # Raises:
     - `ValueError` : if no powers are requested, if any power is negative,
       or if `A` is not a square matrix
    """

    import torch

    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"Matrix must be square, but got {A.shape = }")

    dim_n: int = A.shape[0]
    # sorted + deduplicated, so index 0 is the minimum and index -1 the maximum
    unique_powers: List[int] = sorted(set(powers))
    n_powers_unique: int = len(unique_powers)

    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")

    # a negative power would skip the decomposition loop below and silently
    # come back as the identity matrix, so reject it explicitly
    if unique_powers[0] < 0:
        raise ValueError(f"Powers must be non-negative, but got {powers = }")

    # every slice is assigned in the loop below, so no fill value is needed;
    # filling with NaN is what made integer-dtype inputs fail
    output = torch.empty(
        (n_powers_unique, dim_n, dim_n),
        dtype=A.dtype,
        device=A.device,
    )

    # read the maximum from the Python list rather than building a tensor on
    # `A.device` only to pull a single scalar back out of it
    max_power: int = int(unique_powers[-1])

    # repeated squarings, keyed by exponent: {1: A, 2: A^2, 4: A^4, ...}
    # only exponents <= max_power can appear in a binary decomposition
    squarings = {1: A}
    p: int = 1
    while p * 2 <= max_power:
        squarings[p * 2] = squarings[p] @ squarings[p]
        p *= 2

    identity = torch.eye(dim_n, dtype=A.dtype, device=A.device)

    for p_idx, power in enumerate(unique_powers):
        # walk the bits of `power` from least to most significant,
        # multiplying in the squaring that corresponds to each set bit
        result = identity
        remaining: int = int(power)
        bit_value: int = 1

        while remaining > 0:
            if remaining % 2 == 1:
                result = result @ squarings[bit_value]
            remaining //= 2
            bit_value *= 2

        # slice assignment copies the values into `output`
        output[p_idx] = result

    return output