Coverage for muutils/tensor_utils.py: 90%
124 statements
coverage.py v7.11.0, created at 2025-10-28 17:24 +0000
1"""utilities for working with tensors and arrays.
3notably:
5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
6- `DTYPE_MAP` mapping string representations of types to their type
7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match
10"""
12from __future__ import annotations
14import json
15import typing
17import jaxtyping
18import numpy as np
19import torch
21from muutils.errormode import ErrorMode
22from muutils.dictmagic import dotlist_to_nested_dict
24# pylint: disable=missing-class-docstring
TYPE_TO_JAX_DTYPE: dict = {
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    bool: jaxtyping.Bool,
    jaxtyping.Bool: jaxtyping.Bool,
    np.bool_: jaxtyping.Bool,
    torch.bool: jaxtyping.Bool,
    # numpy float
    np.float16: jaxtyping.Float,
    np.float32: jaxtyping.Float,
    np.float64: jaxtyping.Float,
    np.half: jaxtyping.Float,
    np.single: jaxtyping.Float,
    np.double: jaxtyping.Float,
    # numpy int
    np.int8: jaxtyping.Int,
    np.int16: jaxtyping.Int,
    np.int32: jaxtyping.Int,
    np.int64: jaxtyping.Int,
    np.longlong: jaxtyping.Int,
    np.short: jaxtyping.Int,
    np.uint8: jaxtyping.Int,
    # torch float
    torch.float: jaxtyping.Float,
    torch.float16: jaxtyping.Float,
    torch.float32: jaxtyping.Float,
    torch.float64: jaxtyping.Float,
    torch.half: jaxtyping.Float,
    torch.double: jaxtyping.Float,
    torch.bfloat16: jaxtyping.Float,
    # torch int
    torch.int: jaxtyping.Int,
    torch.int8: jaxtyping.Int,
    torch.int16: jaxtyping.Int,
    torch.int32: jaxtyping.Int,
    torch.int64: jaxtyping.Int,
    torch.long: jaxtyping.Int,
    torch.short: jaxtyping.Int,
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"
# `np.float_` and `np.int_` are only registered on older numpy versions, where they
# are guaranteed to exist, so the attribute accesses below shouldn't error
if np.version.version < "2.0.0":
    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float  # type: ignore[attr-defined]
    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int  # type: ignore[attr-defined]
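
# Example (illustrative, not part of the original module): the mapping is intended
# for plain dict lookups, e.g. when building annotations programmatically.
#
#     >>> TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float
#     True
#     >>> TYPE_TO_JAX_DTYPE[torch.int64] is jaxtyping.Int
#     True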


# TODO: add proper type annotations to this signature
# TODO: maybe get rid of this altogether?
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode_ = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            raise TypeError(f"Type {cls.__name__} cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode_.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = list()
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]
            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        _BaseArray.__doc__ = "default_jax_dtype = {default_jax_dtype}\narray_type = {array_type}"

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray


if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()


ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
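
# Example (illustrative, not from the original module): both factory products are
# annotation shorthand; "batch dim" here is an arbitrary shape string chosen for
# the example.
#
#     >>> def scale(x: ATensor["batch dim", torch.float32]) -> ATensor["batch dim", torch.float32]:
#     ...     return x * 2
#     >>> arr_hint = NDArray["batch dim", np.float32]  # same as jaxtyping.Float[np.ndarray, "batch dim"]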


def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
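
# Example (illustrative): torch dtypes pass through unchanged, while numpy dtypes
# are converted via a zero-dim array round trip through `torch.from_numpy`.
#
#     >>> numpy_to_torch_dtype(np.dtype("float32"))
#     torch.float32
#     >>> numpy_to_torch_dtype(torch.int64)
#     torch.int64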


DTYPE_LIST: list = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # simplest
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.longlong,
        np.short,
        # simplest
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"

if np.version.version < "2.0.0":
    DTYPE_LIST.extend([np.float_, np.int_])  # type: ignore[attr-defined]

DTYPE_MAP: dict = {
    **{str(x): x for x in DTYPE_LIST},
    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
}
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict = {
    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
}
"mapping from string representations of types to specifically torch types"

# the plain key "bool" never appears above (str(bool) is "<class 'bool'>" and
# np.bool_.__name__ is "bool_"), so register it explicitly
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
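
# Example (illustrative): keys come from `str(dtype)` and, for numpy types, the bare
# `__name__`, so several spellings resolve to the same dtype.
#
#     >>> DTYPE_MAP["float32"] is np.float32
#     True
#     >>> DTYPE_MAP["torch.float32"] is torch.float32
#     True
#     >>> TORCH_DTYPE_MAP["float32"]
#     torch.float32
#     >>> TORCH_DTYPE_MAP["bool"]
#     torch.bool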


TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}
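
# Example (illustrative sketch): resolving an optimizer class from a config string.
# The Linear layer and learning rate are placeholders, not part of this module.
#
#     >>> model = torch.nn.Linear(4, 2)
#     >>> optimizer = TORCH_OPTIMIZERS_MAP["AdamW"](model.parameters(), lr=1e-3)
#     >>> isinstance(optimizer, torch.optim.AdamW)
#     True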


def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)


def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)


def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
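
# Example (illustrative): left padding is the default and `rpad_tensor` pads on the
# right; the padding values take their dtype and device from the input tensor.
#
#     >>> pad_tensor(torch.tensor([1.0, 2.0, 3.0]), 5)
#     tensor([0., 0., 1., 2., 3.])
#     >>> rpad_tensor(torch.tensor([1.0, 2.0, 3.0]), 5, pad_value=-1.0)
#     tensor([ 1.,  2.,  3., -1., -1.])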


def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)


def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)


def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
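
# Example (illustrative): the numpy counterparts behave the same way as the tensor
# padding helpers above.
#
#     >>> pad_array(np.array([1.0, 2.0, 3.0]), 5)
#     array([0., 0., 1., 2., 3.])
#     >>> rpad_array(np.array([1.0, 2.0, 3.0]), 5, pad_value=-1.0)
#     array([ 1.,  2.,  3., -1., -1.])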


def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})


def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent won't play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
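
# Example (illustrative, assuming `dotlist_to_nested_dict` splits keys on "."): dotted
# keys are expanded into nested dicts, with shapes stored as tuples.
#
#     >>> get_dict_shapes({"embed.weight": torch.zeros(10, 4), "embed.bias": torch.zeros(4)})
#     {'embed': {'weight': (10, 4), 'bias': (4,)}}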


class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass


class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass


class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass


class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass


def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
      (defaults to `1e-5`)
    - `atol : float`
      (defaults to `1e-8`)
    - `verbose : bool`
      (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
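
# Example (illustrative sketch): a typical call; since the specific errors subclass
# `StateDictCompareError`, one `except` clause catches key, shape, and value mismatches.
#
#     >>> sd_a = {"w": torch.ones(2, 2)}
#     >>> sd_b = {"w": torch.ones(2, 2) + 1e-3}
#     >>> try:
#     ...     compare_state_dicts(sd_a, sd_b)
#     ... except StateDictCompareError as e:
#     ...     print(type(e).__name__)
#     StateDictValueError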