Coverage for tests/unit/math/test_matrix_powers_torch.py: 98% (91 statements)
coverage.py v7.11.0, created at 2025-10-28 17:24 +0000
from __future__ import annotations

from typing import List, Tuple
import warnings
import numpy as np
import pytest
from jaxtyping import Float
import torch

from muutils.dbg import dbg_tensor
from muutils.math.matrix_powers import matrix_powers, matrix_powers_torch
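

# Illustrative sketch (not part of the original module, and not used by the tests
# below): one plausible square-and-multiply scheme behind matrix_powers, assuming it
# computes A**p for each unique non-negative power with O(log p) matrix
# multiplications and stacks the results in ascending power order. The real muutils
# implementation may differ.
def _matrix_powers_sketch(
    A: Float[np.ndarray, "n n"], powers: List[int]
) -> Float[np.ndarray, "n_unique_powers n n"]:
    unique_powers: List[int] = sorted(set(powers))
    if not unique_powers:
        raise ValueError("powers must be a non-empty list")
    results: List[np.ndarray] = []
    for p in unique_powers:
        # square-and-multiply: O(log p) matrix multiplications per power
        acc = np.eye(A.shape[0], dtype=A.dtype)
        base = A.copy()
        exponent = p
        while exponent > 0:
            if exponent & 1:
                acc = acc @ base
            base = base @ base
            exponent >>= 1
        results.append(acc)
    return np.stack(results)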


class TestMatrixPowers:
    @pytest.fixture
    def sample_matrices(self) -> List[Tuple[str, Float[np.ndarray, "n n"]]]:
        """Return a list of test matrices with diverse properties."""
        return [
            ("identity", np.eye(3)),
            ("diagonal", np.diag([2.0, 3, 4])),
            ("nilpotent", np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0.0]])),
            # ("random_int", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
            ("random_float", np.random.rand(4, 4)),
            ("complex", np.array([[1 + 1j, 2], [3, 4 - 2j]])),
        ]

    @pytest.fixture
    def power_test_cases(self) -> List[List[int]]:
        """Return test cases for powers to compute."""
        return [
            [0],
            [1],
            [2],
            [0, 1, 2],
            [5],
            [0, 1, 5, 10],
            [1, 2, 4, 8, 16],
            list(range(10)),
        ]

    def test_against_numpy_implementation(
        self,
        sample_matrices: List[Tuple[str, Float[np.ndarray, "n n"]]],
        power_test_cases: List[List[int]],
    ) -> None:
        """Test that matrix_powers and matrix_powers_torch match numpy.linalg.matrix_power."""
        for name, matrix in sample_matrices:
            for powers in power_test_cases:
                # Compute with both implementations
                dbg_tensor(matrix)
                result = matrix_powers(matrix, powers)
                result_torch = (
                    matrix_powers_torch(torch.tensor(matrix), powers).cpu().numpy()
                )

                # Get dimension information
                n_powers = len(set(powers))
                dim_n = matrix.shape[0]

                # Check shape of output
                assert result.shape == (
                    n_powers,
                    dim_n,
                    dim_n,
                ), f"Incorrect shape for {name} matrix with powers {powers}"
                assert result_torch.shape == (
                    n_powers,
                    dim_n,
                    dim_n,
                ), f"Incorrect shape for {name} matrix with powers {powers} (torch)"

                # Compare with the numpy implementation for each unique power
                unique_powers = sorted(set(powers))
                for i, power in enumerate(unique_powers):
                    expected = np.linalg.matrix_power(matrix, power)
                    np.testing.assert_allclose(
                        result[i],
                        expected,
                        rtol=1e-10,
                        atol=1e-10,
                        err_msg=f"Failed for {name} matrix to power {power}",
                    )
                    np.testing.assert_allclose(
                        result_torch[i],
                        expected,
                        rtol=1e-10,
                        atol=1e-10,
                        err_msg=f"Failed for {name} matrix to power {power} (torch)",
                    )

    def test_empty_powers_list(self) -> None:
        """Test handling of empty powers list."""
        A = np.eye(3)
        with pytest.raises(ValueError):
            matrix_powers(A, [])

    def test_duplicate_powers(self) -> None:
        """Test handling of duplicate powers in the input list."""
        A = np.diag([2.0, 3, 4])
        powers = [1, 2, 2, 3, 1]
        result = matrix_powers(A, powers)
        result_torch = matrix_powers_torch(torch.tensor(A), powers).cpu().numpy()

        # Should only have 3 unique powers
        assert result.shape == (3, 3, 3)
        assert result_torch.shape == (3, 3, 3)

        # Check each power
        unique_powers = sorted(set(powers))
        for i, power in enumerate(unique_powers):
            expected = np.linalg.matrix_power(A, power)
            np.testing.assert_allclose(result[i], expected)
            np.testing.assert_allclose(result_torch[i], expected)

    def test_non_square_matrix(self) -> None:
        """Test that an assertion error is raised for non-square matrices."""
        A = np.ones((3, 4))
        with pytest.raises(AssertionError):
            matrix_powers(A, [1, 2])

    def test_negative_powers(self) -> None:
        """Test handling of negative powers (should work with invertible matrices)."""
        # Use an invertible matrix
        A = np.array([[1, 2], [3, 4]])
        powers = [-1, -2, 0, 1, 2]

        # This might raise an error if negative powers aren't supported
        try:
            result = matrix_powers(A, powers)

            # If it succeeds, verify the results
            unique_powers = sorted(set(powers))
            for i, power in enumerate(unique_powers):
                expected = np.linalg.matrix_power(A, power)
                np.testing.assert_allclose(result[i], expected)
        except Exception as e:
            pytest.skip(f"Negative powers not supported: {e}")

    def test_large_powers(self) -> None:
        """Test with large powers to verify binary exponentiation efficiency."""
        # Matrix with eigenvalues < 1 to avoid overflow
        A = np.array([[0.5, 0.1], [0.1, 0.5]])
        large_power = 1000
        powers = [large_power]

        result = matrix_powers(A, powers)
        result_torch = matrix_powers_torch(torch.tensor(A), powers).cpu().numpy()
        expected = np.linalg.matrix_power(A, large_power)

        np.testing.assert_allclose(result[0], expected)
        np.testing.assert_allclose(result_torch[0], expected)

    def test_performance(self) -> None:
        """Test that binary exponentiation is more efficient than the naive approach."""
        import time

        A = np.random.randn(64, 64)
        powers = [10, 100] + list(range(1000, 1024)) + list(range(9_000, 9_200))

        # Time our implementation
        start = time.time()
        p_np = matrix_powers(A, powers)
        our_time = time.time() - start

        # Time the torch implementation
        start = time.time()
        p_torch = matrix_powers_torch(torch.tensor(A), powers)
        torch_time = time.time() - start

        # Time the naive approach
        start = time.time()
        p_naive = []
        for power in powers:
            p_naive.append(np.linalg.matrix_power(A, power))
        naive_time = time.time() - start

        assert len(p_np) == len(p_naive), "Output lengths do not match"
        assert len(p_torch) == len(p_naive), "Torch output lengths do not match"

        # Our implementation should be faster for these powers
        if our_time >= naive_time:
            warnings.warn(
                f"Binary exponentiation with `matrix_powers()` ({our_time:.4f}s) not faster than naive approach ({naive_time:.4f}s)"
            )
        if torch_time >= naive_time:
            warnings.warn(
                f"Binary exponentiation with `matrix_powers_torch()` ({torch_time:.4f}s) not faster than naive approach ({naive_time:.4f}s)"
            )
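

# Optional convenience entry point (an addition, not part of the original suite):
# lets the module be run directly with `python test_matrix_powers_torch.py`,
# assuming pytest is installed; normally these tests are collected by pytest itself.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))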