Coverage for tests/unit/math/test_matrix_powers.py: 100%
88 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-30 22:10 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-30 22:10 -0600
1from __future__ import annotations
3from typing import List, Tuple
4import numpy as np
5import pytest
6from jaxtyping import Float
7import torch
9from muutils.dbg import dbg_tensor
10from muutils.math.matrix_powers import matrix_powers, matrix_powers_torch
class TestMatrixPowers:
    """Tests for ``matrix_powers`` and ``matrix_powers_torch``.

    ``numpy.linalg.matrix_power`` is used as the reference throughout. The
    assertions expect each output to be stacked along its first axis with one
    ``(n, n)`` slice per *unique* requested power, in ascending power order.
    """

    @pytest.fixture
    def sample_matrices(self) -> List[Tuple[str, Float[np.ndarray, "n n"]]]:
        """Return ``(name, matrix)`` pairs of test matrices with diverse properties."""
        # Seeded generator: keeps the "random_float" case reproducible so a
        # failure can be re-run with the exact same matrix.
        rng = np.random.default_rng(0)
        return [
            ("identity", np.eye(3)),
            ("diagonal", np.diag([2.0, 3, 4])),
            ("nilpotent", np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0.0]])),
            # NOTE(review): this integer-dtype case was already disabled --
            # confirm whether integer matrices are meant to be supported.
            # ("random_int", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
            ("random_float", rng.random((4, 4))),
            ("complex", np.array([[1 + 1j, 2], [3, 4 - 2j]])),
        ]

    @pytest.fixture
    def power_test_cases(self) -> List[List[int]]:
        """Return the lists of powers to request: singletons, mixed, and ranges."""
        return [
            [0],
            [1],
            [2],
            [0, 1, 2],
            [5],
            [0, 1, 5, 10],
            [1, 2, 4, 8, 16],
            list(range(10)),
        ]

    def test_against_numpy_implementation(
        self,
        sample_matrices: List[Tuple[str, Float[np.ndarray, "n n"]]],
        power_test_cases: List[List[int]],
    ) -> None:
        """Test that both implementations match ``numpy.linalg.matrix_power``.

        Every sample matrix is combined with every list of powers; output
        shapes are checked first, then each slice is compared to the reference.
        """
        for name, matrix in sample_matrices:
            # Loop-invariant per matrix: emit the debug summary once and
            # read the dimension once instead of once per power list.
            dbg_tensor(matrix)
            dim_n = matrix.shape[0]

            for powers in power_test_cases:
                result = matrix_powers(matrix, powers)
                result_torch = (
                    matrix_powers_torch(torch.tensor(matrix), powers).cpu().numpy()
                )

                # One output slice is expected per unique power, ascending.
                unique_powers = sorted(set(powers))
                expected_shape = (len(unique_powers), dim_n, dim_n)

                # (message suffix, output) so a failure names the backend.
                backends = (("", result), (" (torch)", result_torch))

                for suffix, output in backends:
                    assert output.shape == expected_shape, (
                        f"Incorrect shape for {name} matrix "
                        f"with powers {powers}{suffix}"
                    )

                for i, power in enumerate(unique_powers):
                    expected = np.linalg.matrix_power(matrix, power)
                    for suffix, output in backends:
                        np.testing.assert_allclose(
                            output[i],
                            expected,
                            rtol=1e-10,
                            atol=1e-10,
                            err_msg=(
                                f"Failed for {name} matrix to power {power}{suffix}"
                            ),
                        )

    def test_empty_powers_list(self) -> None:
        """Test that an empty powers list raises ``ValueError``."""
        A = np.eye(3)
        with pytest.raises(ValueError):
            matrix_powers(A, [])

    def test_duplicate_powers(self) -> None:
        """Test that duplicate powers in the input are collapsed in the output."""
        A = np.diag([2.0, 3, 4])
        powers = [1, 2, 2, 3, 1]
        result = matrix_powers(A, powers)
        result_torch = matrix_powers_torch(torch.tensor(A), powers).cpu().numpy()

        # Only the 3 unique powers should be present, for both backends.
        assert result.shape == (3, 3, 3)
        assert result_torch.shape == (3, 3, 3)

        unique_powers = sorted(set(powers))
        for i, power in enumerate(unique_powers):
            expected = np.linalg.matrix_power(A, power)
            np.testing.assert_allclose(result[i], expected)
            np.testing.assert_allclose(result_torch[i], expected)

    def test_non_square_matrix(self) -> None:
        """Test that an assertion error is raised for non-square matrices."""
        A = np.ones((3, 4))
        with pytest.raises(AssertionError):
            matrix_powers(A, [1, 2])

    def test_negative_powers(self) -> None:
        """Test handling of negative powers (should work with invertible matrices).

        The test is skipped, not failed, when anything inside the ``try``
        raises.
        """
        # Use an invertible matrix
        A = np.array([[1, 2], [3, 4]])
        powers = [-1, -2, 0, 1, 2]

        # This might raise an error if negative powers aren't supported.
        # NOTE(review): AssertionError derives from Exception, so a mismatch
        # reported by assert_allclose below is also turned into a skip rather
        # than a failure. Kept as-is to preserve the current outcome --
        # confirm whether wrong results should fail instead.
        try:
            result = matrix_powers(A, powers)

            # If it succeeds, verify the results
            unique_powers = sorted(set(powers))
            for i, power in enumerate(unique_powers):
                expected = np.linalg.matrix_power(A, power)
                np.testing.assert_allclose(result[i], expected)
        except Exception as e:
            pytest.skip(f"Negative powers not supported: {e}")

    def test_large_powers(self) -> None:
        """Test a single large power (1000) against the numpy reference."""
        # Matrix with eigenvalues < 1 to avoid overflow
        A = np.array([[0.5, 0.1], [0.1, 0.5]])
        large_power = 1000
        powers = [large_power]

        result = matrix_powers(A, powers)
        result_torch = matrix_powers_torch(torch.tensor(A), powers).cpu().numpy()
        expected = np.linalg.matrix_power(A, large_power)

        np.testing.assert_allclose(result[0], expected)
        np.testing.assert_allclose(result_torch[0], expected)

    def test_performance(self) -> None:
        """Test that both implementations beat per-power ``matrix_power`` calls.

        Only output lengths and elapsed times are compared; the computed
        values themselves are not checked here.
        NOTE: this is a wall-clock comparison and may be sensitive to load on
        the machine running the tests.
        """
        import time

        A = np.random.randn(64, 64)
        powers = [10, 100] + list(range(1000, 1024)) + list(range(9_000, 9_200))

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so the measured intervals cannot be skewed by system clock changes.
        start = time.perf_counter()
        p_np = matrix_powers(A, powers)
        our_time = time.perf_counter() - start

        # Time torch implementation (tensor construction included, as before)
        start = time.perf_counter()
        p_torch = matrix_powers_torch(torch.tensor(A), powers)
        torch_time = time.perf_counter() - start

        # Time naive approach: one independent matrix_power call per power
        start = time.perf_counter()
        p_naive = [np.linalg.matrix_power(A, power) for power in powers]
        naive_time = time.perf_counter() - start

        assert len(p_np) == len(p_naive), "Output lengths do not match"
        assert len(p_torch) == len(p_naive), "Torch output lengths do not match"

        # Both implementations should be faster for these powers
        assert (
            our_time < naive_time
        ), f"Binary exponentiation ({our_time:.4f}s) not faster than naive approach ({naive_time:.4f}s)"
        assert (
            torch_time < naive_time
        ), f"Torch binary exponentiation ({torch_time:.4f}s) not faster than naive approach ({naive_time:.4f}s)"