Coverage for tests/unit/math/test_matrix_powers_torch.py: 98%

91 statements  

Generated by coverage.py v7.11.0 on 2025-10-28 17:24 +0000.

1from __future__ import annotations 

2 

3from typing import List, Tuple 

4import warnings 

5import numpy as np 

6import pytest 

7from jaxtyping import Float 

8import torch 

9 

10from muutils.dbg import dbg_tensor 

11from muutils.math.matrix_powers import matrix_powers, matrix_powers_torch 

12 

13 

class TestMatrixPowers:
    """Tests for ``matrix_powers`` (NumPy) and ``matrix_powers_torch``.

    ``np.linalg.matrix_power`` is used as the reference implementation for
    every requested power.
    """

    @pytest.fixture
    def sample_matrices(self) -> List[Tuple[str, Float[np.ndarray, "n n"]]]:
        """Return a list of test matrices with diverse properties.

        The random matrix comes from a fixed-seed generator so that any
        failure is reproducible from run to run.
        """
        rng = np.random.default_rng(0)
        return [
            ("identity", np.eye(3)),
            ("diagonal", np.diag([2.0, 3, 4])),
            ("nilpotent", np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0.0]])),
            # ("random_int", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
            ("random_float", rng.random((4, 4))),
            ("complex", np.array([[1 + 1j, 2], [3, 4 - 2j]])),
        ]

    @pytest.fixture
    def power_test_cases(self) -> List[List[int]]:
        """Return lists of (non-negative) powers to compute."""
        return [
            [0],
            [1],
            [2],
            [0, 1, 2],
            [5],
            [0, 1, 5, 10],
            [1, 2, 4, 8, 16],
            list(range(10)),
        ]

    def test_against_numpy_implementation(
        self,
        sample_matrices: List[Tuple[str, Float[np.ndarray, "n n"]]],
        power_test_cases: List[List[int]],
    ) -> None:
        """Test that both implementations agree with ``numpy.linalg.matrix_power``.

        For each (matrix, powers) pair, the output must have shape
        ``(len(set(powers)), n, n)``. Slice ``i`` must equal the matrix raised
        to the ``i``-th smallest distinct requested power.
        """
        for name, matrix in sample_matrices:
            # Debug-print each matrix once, not once per power list.
            dbg_tensor(matrix)
            dim_n = matrix.shape[0]

            for powers in power_test_cases:
                result = matrix_powers(matrix, powers)
                result_torch = (
                    matrix_powers_torch(torch.tensor(matrix), powers).cpu().numpy()
                )

                # Outputs are indexed by the sorted, de-duplicated powers.
                unique_powers = sorted(set(powers))
                expected_shape = (len(unique_powers), dim_n, dim_n)

                assert result.shape == expected_shape, (
                    f"Incorrect shape for {name} matrix with powers {powers}"
                )
                assert result_torch.shape == expected_shape, (
                    f"Incorrect shape for {name} matrix with powers {powers} (torch)"
                )

                for i, power in enumerate(unique_powers):
                    expected = np.linalg.matrix_power(matrix, power)
                    np.testing.assert_allclose(
                        result[i],
                        expected,
                        rtol=1e-10,
                        atol=1e-10,
                        err_msg=f"Failed for {name} matrix to power {power}",
                    )
                    np.testing.assert_allclose(
                        result_torch[i],
                        expected,
                        rtol=1e-10,
                        atol=1e-10,
                        err_msg=f"Failed for {name} matrix to power {power} (torch)",
                    )

    def test_empty_powers_list(self) -> None:
        """Test that an empty powers list raises ``ValueError``."""
        A = np.eye(3)
        with pytest.raises(ValueError):
            matrix_powers(A, [])

    def test_duplicate_powers(self) -> None:
        """Test that duplicate powers are collapsed into one output slice each."""
        A = np.diag([2.0, 3, 4])
        powers = [1, 2, 2, 3, 1]
        result = matrix_powers(A, powers)
        result_torch = matrix_powers_torch(torch.tensor(A), powers).cpu().numpy()

        # Only the 3 distinct powers {1, 2, 3} should appear, for both backends.
        assert result.shape == (3, 3, 3)
        assert result_torch.shape == (3, 3, 3)

        for i, power in enumerate(sorted(set(powers))):
            expected = np.linalg.matrix_power(A, power)
            np.testing.assert_allclose(result[i], expected)
            np.testing.assert_allclose(result_torch[i], expected)

    def test_non_square_matrix(self) -> None:
        """Test that an assertion error is raised for non-square matrices."""
        A = np.ones((3, 4))
        with pytest.raises(AssertionError):
            matrix_powers(A, [1, 2])

    def test_negative_powers(self) -> None:
        """Test negative powers on an invertible matrix, when supported.

        Only the ``matrix_powers`` call is guarded. If it rejects negative
        powers by raising, the test is skipped. If it returns values, they
        must be correct; a numerical mismatch fails rather than skips.
        """
        # det = 1*4 - 2*3 = -2, so A is invertible.
        A = np.array([[1, 2], [3, 4]])
        powers = [-1, -2, 0, 1, 2]

        try:
            result = matrix_powers(A, powers)
        except Exception as e:
            pytest.skip(f"Negative powers not supported: {e}")

        for i, power in enumerate(sorted(set(powers))):
            expected = np.linalg.matrix_power(A, power)
            np.testing.assert_allclose(result[i], expected)

    def test_large_powers(self) -> None:
        """Test with a large power to exercise binary exponentiation."""
        # Eigenvalues are 0.6 and 0.4 (< 1), so A**1000 does not overflow.
        A = np.array([[0.5, 0.1], [0.1, 0.5]])
        large_power = 1000
        powers = [large_power]

        result = matrix_powers(A, powers)
        result_torch = matrix_powers_torch(torch.tensor(A), powers).cpu().numpy()
        expected = np.linalg.matrix_power(A, large_power)

        np.testing.assert_allclose(result[0], expected)
        np.testing.assert_allclose(result_torch[0], expected)

    def test_performance(self) -> None:
        """Compare both implementations' timing against one ``matrix_power`` call per power.

        Timing is machine dependent, so being slower than the naive approach
        only emits a warning. The test fails only if the output lengths
        disagree.
        """
        import time

        # An orthogonal matrix (all singular values 1) keeps every power
        # bounded. A raw 64x64 Gaussian matrix has spectral radius ~8 and
        # overflows to inf/NaN around power ~340, which would make the
        # benchmark time inf/NaN propagation instead of normal arithmetic.
        rng = np.random.default_rng(0)
        A, _ = np.linalg.qr(rng.standard_normal((64, 64)))
        powers = [10, 100] + list(range(1000, 1024)) + list(range(9_000, 9_200))

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        p_np = matrix_powers(A, powers)
        our_time = time.perf_counter() - start

        start = time.perf_counter()
        p_torch = matrix_powers_torch(torch.tensor(A), powers)
        torch_time = time.perf_counter() - start

        start = time.perf_counter()
        p_naive = [np.linalg.matrix_power(A, power) for power in powers]
        naive_time = time.perf_counter() - start

        # `powers` has no duplicates, so every output has one entry per power.
        assert len(p_np) == len(p_naive), "Output lengths do not match"
        assert len(p_torch) == len(p_naive), "Torch output lengths do not match"

        if our_time >= naive_time:
            warnings.warn(
                f"Warning: Binary exponentiation with `matrix_powers()` ({our_time:.4f}s) not faster than naive approach ({naive_time:.4f}s)"
            )
        if torch_time >= naive_time:
            warnings.warn(
                f"Warning: Binary exponentiation with `matrix_powers_torch()` ({torch_time:.4f}s) not faster than naive approach ({naive_time:.4f}s)"
            )