Coverage for tests/unit/math/test_matrix_powers.py: 100%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-30 22:10 -0600

1from __future__ import annotations 

2 

3from typing import List, Tuple 

4import numpy as np 

5import pytest 

6from jaxtyping import Float 

7import torch 

8 

9from muutils.dbg import dbg_tensor 

10from muutils.math.matrix_powers import matrix_powers, matrix_powers_torch 

11 

12 

class TestMatrixPowers:
    """Tests for ``matrix_powers`` (numpy) and ``matrix_powers_torch``.

    Results are compared against ``np.linalg.matrix_power``. The assertions in
    this class expect one ``(n, n)`` slice per *unique* requested power, ordered
    by ascending power.
    """

    @pytest.fixture
    def sample_matrices(self) -> List[Tuple[str, Float[np.ndarray, "n n"]]]:
        """Return a list of test matrices with diverse properties.

        The random matrix comes from a fixed-seed generator so that any
        failure is reproducible from run to run.
        """
        rng = np.random.default_rng(0)
        return [
            ("identity", np.eye(3)),
            ("diagonal", np.diag([2.0, 3, 4])),
            ("nilpotent", np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0.0]])),
            # Disabled integer-dtype case.
            # NOTE(review): its 16th power has entries beyond the int64 range
            # (largest eigenvalue is ~16.1), which may be why it was disabled —
            # confirm before re-enabling.
            # ("random_int", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
            ("random_float", rng.random((4, 4))),
            ("complex", np.array([[1 + 1j, 2], [3, 4 - 2j]])),
        ]

    @pytest.fixture
    def power_test_cases(self) -> List[List[int]]:
        """Return lists of (non-negative, duplicate-free) powers to compute."""
        return [
            [0],
            [1],
            [2],
            [0, 1, 2],
            [5],
            [0, 1, 5, 10],
            [1, 2, 4, 8, 16],
            list(range(10)),
        ]

    def test_against_numpy_implementation(
        self,
        sample_matrices: List[Tuple[str, Float[np.ndarray, "n n"]]],
        power_test_cases: List[List[int]],
    ) -> None:
        """Test that both implementations match numpy.linalg.matrix_power."""
        for name, matrix in sample_matrices:
            # Show each matrix once, not once per list of powers.
            dbg_tensor(matrix)
            dim_n = matrix.shape[0]

            for powers in power_test_cases:
                result = matrix_powers(matrix, powers)
                result_torch = (
                    matrix_powers_torch(torch.tensor(matrix), powers).cpu().numpy()
                )

                unique_powers = sorted(set(powers))
                expected_shape = (len(unique_powers), dim_n, dim_n)

                assert (
                    result.shape == expected_shape
                ), f"Incorrect shape for {name} matrix with powers {powers}"
                assert (
                    result_torch.shape == expected_shape
                ), f"Incorrect shape for {name} matrix with powers {powers} (torch)"

                # Slice i of each output corresponds to the i-th smallest
                # unique power.
                for i, power in enumerate(unique_powers):
                    expected = np.linalg.matrix_power(matrix, power)
                    np.testing.assert_allclose(
                        result[i],
                        expected,
                        rtol=1e-10,
                        atol=1e-10,
                        err_msg=f"Failed for {name} matrix to power {power}",
                    )
                    np.testing.assert_allclose(
                        result_torch[i],
                        expected,
                        rtol=1e-10,
                        atol=1e-10,
                        err_msg=f"Failed for {name} matrix to power {power} (torch)",
                    )

    def test_empty_powers_list(self) -> None:
        """Test that an empty powers list raises ValueError."""
        A = np.eye(3)
        with pytest.raises(ValueError):
            matrix_powers(A, [])

    def test_duplicate_powers(self) -> None:
        """Test that duplicate powers are collapsed to one output each."""
        A = np.diag([2.0, 3, 4])
        powers = [1, 2, 2, 3, 1]
        result = matrix_powers(A, powers)
        result_torch = matrix_powers_torch(torch.tensor(A), powers).cpu().numpy()

        # Only the 3 unique powers {1, 2, 3} should be present, in both outputs.
        assert result.shape == (3, 3, 3)
        assert result_torch.shape == (3, 3, 3)

        unique_powers = sorted(set(powers))
        for i, power in enumerate(unique_powers):
            expected = np.linalg.matrix_power(A, power)
            np.testing.assert_allclose(result[i], expected)
            np.testing.assert_allclose(result_torch[i], expected)

    def test_non_square_matrix(self) -> None:
        """Test that an assertion error is raised for non-square matrices."""
        A = np.ones((3, 4))
        with pytest.raises(AssertionError):
            matrix_powers(A, [1, 2])

    def test_negative_powers(self) -> None:
        """Test negative powers of an invertible matrix.

        The test is skipped only if ``matrix_powers`` itself raises, i.e. it
        refuses negative powers. If it returns a result, that result must be
        correct. A mismatch fails the test and is not turned into a skip.
        """
        # det = -2, so A is invertible.
        A = np.array([[1, 2], [3, 4]])
        powers = [-1, -2, 0, 1, 2]

        # Keep the try minimal: catching Exception around the verification
        # below would also swallow assert_allclose's AssertionError.
        try:
            result = matrix_powers(A, powers)
        except Exception as e:
            pytest.skip(f"Negative powers not supported: {e}")

        unique_powers = sorted(set(powers))
        for i, power in enumerate(unique_powers):
            expected = np.linalg.matrix_power(A, power)
            np.testing.assert_allclose(result[i], expected)

    def test_large_powers(self) -> None:
        """Test correctness at a large power (1000) for both implementations."""
        # Eigenvalues are 0.6 and 0.4, so A**1000 stays finite and
        # representable (no overflow).
        A = np.array([[0.5, 0.1], [0.1, 0.5]])
        large_power = 1000
        powers = [large_power]

        result = matrix_powers(A, powers)
        result_torch = matrix_powers_torch(torch.tensor(A), powers).cpu().numpy()
        expected = np.linalg.matrix_power(A, large_power)

        np.testing.assert_allclose(result[0], expected)
        np.testing.assert_allclose(result_torch[0], expected)

    def test_performance(self) -> None:
        """Test that both implementations beat one matrix_power call per power.

        NOTE: this compares wall-clock times and may be sensitive to machine
        load.
        """
        import time

        # Seeded for reproducibility. Entries are expected to overflow at these
        # powers, so only output counts and timings are compared.
        A = np.random.default_rng(0).standard_normal((64, 64))
        powers = [10, 100] + list(range(1000, 1024)) + list(range(9_000, 9_200))

        # perf_counter is monotonic and high-resolution, unlike time.time.
        start = time.perf_counter()
        p_np = matrix_powers(A, powers)
        our_time = time.perf_counter() - start

        start = time.perf_counter()
        p_torch = matrix_powers_torch(torch.tensor(A), powers)
        torch_time = time.perf_counter() - start

        # Naive baseline: one independent matrix_power call per power.
        start = time.perf_counter()
        p_naive = [np.linalg.matrix_power(A, power) for power in powers]
        naive_time = time.perf_counter() - start

        # `powers` has no duplicates, so each output should have one entry
        # per power.
        assert len(p_np) == len(p_naive), "Output lengths do not match"
        assert len(p_torch) == len(p_naive), "Torch output lengths do not match"

        assert (
            our_time < naive_time
        ), f"Binary exponentiation ({our_time:.4f}s) not faster than naive approach ({naive_time:.4f}s)"
        assert (
            torch_time < naive_time
        ), f"Torch binary exponentiation ({torch_time:.4f}s) not faster than naive approach ({naive_time:.4f}s)"