Coverage for muutils/math/matrix_powers.py: 96%
67 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-30 22:10 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-30 22:10 -0600
1from __future__ import annotations
3from typing import List, Sequence, TYPE_CHECKING
5import numpy as np
6from jaxtyping import Float, Int
8if TYPE_CHECKING:
9 pass
def matrix_powers(
    A: Float[np.ndarray, "n n"],
    powers: Sequence[int],
) -> Float[np.ndarray, "n_powers_unique n n"]:
    """Compute multiple powers of a matrix efficiently.

    Uses binary exponentiation. The matrices `A^1, A^2, A^4, ...` (up to the
    largest power of two not exceeding `max(powers)`) are computed once by
    repeated squaring. Each requested power is then assembled from the binary
    decomposition of its exponent. Duplicate powers are only computed once.

    # Parameters:
     - `A : Float[np.ndarray, "n n"]`
        Square (2D) matrix to exponentiate. Integer dtypes are supported.
     - `powers : Sequence[int]`
        Powers to compute (non-negative integers). Duplicates are removed and
        the output is ordered by increasing power.

    # Returns:
     - `Float[np.ndarray, "n_powers_unique n n"]`
        Array whose `i`-th slice is `A` raised to the `i`-th smallest unique
        requested power. It has the same dtype as `A`.

    # Raises:
     - `AssertionError` : if `A` is not a 2D square matrix
     - `ValueError` : if no powers are requested, or any power is negative
    """
    # check ndim first so a 1D input fails here instead of with an IndexError
    assert A.ndim == 2 and A.shape[0] == A.shape[1], (
        f"Matrix must be square, but got {A.shape = }"
    )
    dim_n: int = A.shape[0]

    # unique, sorted exponents; the output is indexed in this order
    powers_np: Int[np.ndarray, "n_powers_unique"] = np.array(
        sorted(set(powers)), dtype=int
    )
    n_powers_unique: int = len(powers_np)
    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")
    if powers_np[0] < 0:
        # negative exponents would need a matrix inverse; without this check
        # they silently produced the identity matrix
        raise ValueError(f"Powers must be non-negative, but got {powers = }")

    # sorted ascending, so the last entry is the largest requested power
    max_power: int = int(powers_np[-1])

    # maps k -> A^k for k = 0 (identity) and every power of two <= max_power.
    # Entries are never mutated in place, so no defensive copies are needed.
    powers_of_two: dict[int, Float[np.ndarray, "n n"]] = {
        0: np.eye(dim_n, dtype=A.dtype),
        1: A,
    }
    p: int = 1
    # stop before exceeding max_power: larger powers of two are never used
    while p * 2 <= max_power:
        powers_of_two[p * 2] = powers_of_two[p] @ powers_of_two[p]
        p *= 2

    # every slice is overwritten below, so an uninitialized buffer is safe.
    # Pre-filling with NaN would fail for integer dtypes.
    output: Float[np.ndarray, "n_powers_unique n n"] = np.empty(
        (n_powers_unique, dim_n, dim_n), dtype=A.dtype
    )

    for p_idx, power in enumerate(powers_np):
        # multiply together A^(2^k) for every set bit k of `power`
        result: Float[np.ndarray, "n n"] = powers_of_two[0]
        remaining: int = int(power)
        bit: int = 1
        while remaining > 0:
            if remaining & 1:
                # bit <= power <= max_power, so this key was precomputed
                result = result @ powers_of_two[bit]
            remaining >>= 1
            bit <<= 1
        output[p_idx] = result

    return output
# NOTE: integer matrices used to fail because the output buffer was pre-filled
# with NaN, which cannot be stored in an integer dtype. Integer matmul may still
# be unsupported on some devices (e.g. CUDA) -- TODO confirm.
# TYPING: jaxtyping hints not working here, separate file for torch implementation?
def matrix_powers_torch(
    A,  # : Float["torch.Tensor", "n n"],
    powers: Sequence[int],
):  # Float["torch.Tensor", "n_powers n n"]:
    """Compute multiple powers of a matrix efficiently.

    Uses binary exponentiation. The matrices `A^1, A^2, A^4, ...` (up to the
    largest power of two not exceeding `max(powers)`) are computed once by
    repeated squaring. Each requested power is then assembled from the binary
    decomposition of its exponent. Duplicate powers are only computed once.

    # Parameters:
     - `A : Float[torch.Tensor, "n n"]`
        Square (2D) matrix to exponentiate
     - `powers : Sequence[int]`
        Powers to compute (non-negative integers). Duplicates are removed and
        the output is ordered by increasing power.

    # Returns:
     - `Float[torch.Tensor, "n_powers_unique n n"]`
        Tensor whose `i`-th slice is `A` raised to the `i`-th smallest unique
        requested power. It has the same dtype and device as `A`.

    # Raises:
     - `ValueError` : if A is not a square matrix, if no powers are requested,
       or if any power is negative
    """
    import torch

    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"Matrix must be square, but got {A.shape = }")
    dim_n: int = A.shape[0]

    # unique, sorted exponents; the output is indexed in this order
    unique_powers: List[int] = sorted(set(powers))
    n_powers_unique: int = len(unique_powers)
    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")
    if unique_powers[0] < 0:
        # negative exponents would need a matrix inverse; without this check
        # they silently produced the identity matrix
        raise ValueError(f"Powers must be non-negative, but got {powers = }")

    # sorted ascending, so the last entry is the largest requested power
    max_power: int = int(unique_powers[-1])

    # maps k -> A^k for k = 0 (identity) and every power of two <= max_power.
    # Entries are never mutated in place, so no defensive clones are needed.
    powers_of_two: dict[int, Float[torch.Tensor, "n n"]] = {
        0: torch.eye(dim_n, dtype=A.dtype, device=A.device),
        1: A,
    }
    p: int = 1
    # stop before exceeding max_power: larger powers of two are never used
    while p * 2 <= max_power:
        powers_of_two[p * 2] = powers_of_two[p] @ powers_of_two[p]
        p *= 2

    # every slice is overwritten below, so an uninitialized buffer is safe.
    # Pre-filling with NaN would fail for integer dtypes.
    output: Float[torch.Tensor, "n_powers_unique n n"] = torch.empty(
        (n_powers_unique, dim_n, dim_n),
        dtype=A.dtype,
        device=A.device,
    )

    for p_idx, power in enumerate(unique_powers):
        # multiply together A^(2^k) for every set bit k of `power`
        result: Float[torch.Tensor, "n n"] = powers_of_two[0]
        remaining: int = int(power)
        bit: int = 1
        while remaining > 0:
            if remaining & 1:
                # bit <= power <= max_power, so this key was precomputed
                result = result @ powers_of_two[bit]
            remaining >>= 1
            bit <<= 1
        output[p_idx] = result

    return output