Coverage for muutils/tensor_utils.py: 86%
133 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""utilities for working with tensors and arrays.
3notably:
5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
6- `DTYPE_MAP` mapping string representations of types to their type
7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
10"""
12from __future__ import annotations
14import json
15import typing
17import jaxtyping
18import numpy as np
19import torch
21from muutils.errormode import ErrorMode
22from muutils.dictmagic import dotlist_to_nested_dict
24# pylint: disable=missing-class-docstring
TYPE_TO_JAX_DTYPE: dict = {
    # python builtins and the jaxtyping categories themselves
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    **dict.fromkeys(
        (bool, jaxtyping.Bool, np.bool_, torch.bool),
        jaxtyping.Bool,
    ),
    # numpy float
    **dict.fromkeys(
        (np.float16, np.float32, np.float64, np.half, np.single, np.double),
        jaxtyping.Float,
    ),
    # numpy int
    **dict.fromkeys(
        (np.int8, np.int16, np.int32, np.int64, np.longlong, np.short, np.uint8),
        jaxtyping.Int,
    ),
    # torch float
    **dict.fromkeys(
        (
            torch.float,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.half,
            torch.double,
            torch.bfloat16,
        ),
        jaxtyping.Float,
    ),
    # torch int
    **dict.fromkeys(
        (
            torch.int,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.long,
            torch.short,
        ),
        jaxtyping.Int,
    ),
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"
# The `np.float_` / `np.int_` aliases are only registered on NumPy 1.x
# (`np.float_` was removed in NumPy 2.0).
# The major version is compared as an integer: comparing the version strings
# directly is lexicographic, so e.g. "10.0.0" < "2.0.0" would be True.
if int(np.version.version.split(".")[0]) < 2:
    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float
    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int
76# TODO: add proper type annotations to this signature
77# TODO: maybe get rid of this altogether?
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
) -> type:
    """build a subscriptable shorthand class for `jaxtyping` annotations

    usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```

    supported subscripts of the returned class:

    - `MyTensor["dim1 dim2"]` : uses `default_jax_dtype`
    - `MyTensor["dim1 dim2", dtype]` : `dtype` is looked up in `TYPE_TO_JAX_DTYPE`
    - `MyTensor[("dim1", "dim2"), dtype]` : legacy form, reported through `legacy_mode`

    # Parameters:

    - `name : str`
        assigned to `__name__` of the returned class
    - `array_type : type`
        array class handed to the jaxtyping dtype as its first subscript
    - `default_jax_dtype`
        jaxtyping dtype used when only a shape string is given
        (defaults to `jaxtyping.Float`)
    - `legacy_mode : typing.Union[ErrorMode, str]`
        converted with `ErrorMode.from_any`; its `process` method is called
        whenever the legacy tuple-shape form is used
        (defaults to `ErrorMode.WARN`)

    # Returns:

    - `type`
        a class that raises `TypeError` on instantiation or subclassing and
        whose `__class_getitem__` returns the subscripted jaxtyping dtype

    # Raises (when the returned class is subscripted):

    - `Exception` : the subscript is not one of the supported forms
    - `KeyError` : the dtype is not a key of `TYPE_TO_JAX_DTYPE`
    """
    legacy_mode_ = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            # report the real class name instead of a hard-coded one
            raise TypeError(f"Type {cls.__name__} cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            # `cls` is the subclass being created here, so name the class
            # that actually forbids subclassing
            raise TypeError(f"Cannot subclass {_BaseArray.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            if not isinstance(params, tuple):
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

            if len(params) != 2:
                raise Exception(
                    f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                )

            shape, dtype = params

            # MyTensor["dim1 dim2", int]
            if isinstance(shape, str):
                return TYPE_TO_JAX_DTYPE[dtype][array_type, shape]

            # MyTensor[("dim1", "dim2"), int]
            if isinstance(shape, tuple):
                legacy_mode_.process(
                    f"legacy type annotation was used:\n{cls.param_info(params) = }",
                    except_cls=Exception,
                )
                shape_anot: list[str] = list()
                for x in shape:
                    if isinstance(x, str):
                        shape_anot.append(x)
                    elif isinstance(x, int):
                        shape_anot.append(str(x))
                    elif isinstance(x, tuple):
                        shape_anot.append("".join(str(y) for y in x))
                    else:
                        raise Exception(
                            f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                        )

                return TYPE_TO_JAX_DTYPE[dtype][array_type, " ".join(shape_anot)]

            # any other shape spec is unsupported: raise rather than
            # implicitly returning None as the "annotation"
            raise Exception(f"unexpected type for params:\n{cls.param_info(params)}")

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        # the class docstring is absent (e.g. stripped by the interpreter);
        # the fallback must use plain `str.format` fields, because the
        # f-string-only `{x = }` form makes `str.format` raise KeyError
        _BaseArray.__doc__ = (
            "default_jax_dtype = {default_jax_dtype}\narray_type = {array_type}"
        )

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray
if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    # the stub mirrors the runtime definition below, which wraps `np.ndarray`
    class NDArray(np.ndarray):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()


# runtime definitions: these rebind the names declared in the type-checking stubs
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype

    a `torch.dtype` is returned unchanged; anything else is used as the `dtype`
    of a zero-dimensional numpy array, which is handed to `torch.from_numpy`
    so that torch reports the matching dtype
    """
    if not isinstance(dtype, torch.dtype):
        probe: np.ndarray = np.array(0, dtype=dtype)
        return torch.from_numpy(probe).dtype
    return dtype
DTYPE_LIST: list = [
    # ----------
    # python builtins
    # ----------
    bool,
    int,
    float,
    # ----------
    # pytorch
    # ----------
    # floats
    torch.float,
    torch.float32,
    torch.float64,
    torch.half,
    torch.double,
    torch.bfloat16,
    # complex
    torch.complex64,
    torch.complex128,
    # ints
    torch.int,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
    torch.long,
    torch.short,
    # simplest
    torch.uint8,
    torch.bool,
    # ----------
    # numpy
    # ----------
    # floats
    np.float16,
    np.float32,
    np.float64,
    np.half,
    np.single,
    np.double,
    # complex
    np.complex64,
    np.complex128,
    # ints
    np.int8,
    np.int16,
    np.int32,
    np.int64,
    np.longlong,
    np.short,
    # simplest
    np.uint8,
    np.bool_,
]
"list of all the python, numpy, and torch numerical types I could think of"
# The `np.float_` / `np.int_` aliases are only appended on NumPy 1.x
# (`np.float_` was removed in NumPy 2.0).
# The major version is compared as an integer: comparing the version strings
# directly is lexicographic, so e.g. "10.0.0" < "2.0.0" would be True.
if int(np.version.version.split(".")[0]) < 2:
    DTYPE_LIST.extend([np.float_, np.int_])
DTYPE_MAP: dict = {str(dtype): dtype for dtype in DTYPE_LIST}
"mapping from string representations of types to their type"

# entries whose `__module__` is "numpy" are additionally keyed by their bare `__name__`
DTYPE_MAP.update(
    {
        np_type.__name__: np_type
        for np_type in DTYPE_LIST
        if np_type.__module__ == "numpy"
    }
)

TORCH_DTYPE_MAP: dict = {
    type_name: numpy_to_torch_dtype(DTYPE_MAP[type_name]) for type_name in DTYPE_MAP
}
"mapping from string representations of types to specifically torch types"

# pin the plain "bool" key to the numpy / torch boolean types,
# whatever the comprehensions above produced for it
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
# optimizer class name -> the class of that name in `torch.optim`
TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    optimizer_name: getattr(torch.optim, optimizer_name)
    for optimizer_name in (
        "Adagrad",
        "Adam",
        "AdamW",
        "SparseAdam",
        "Adamax",
        "ASGD",
        "LBFGS",
        "NAdam",
        "RAdam",
        "RMSprop",
        "Rprop",
        "SGD",
    )
}
def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead

    the padding is created with the dtype and device of `tensor`"""
    n_missing: int = padded_length - tensor.shape[0]
    filler: torch.Tensor = torch.full(
        (n_missing,),
        pad_value,
        dtype=tensor.dtype,
        device=tensor.device,
    )

    if rpad:
        return torch.cat([tensor, filler])
    return torch.cat([filler, tensor])
def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    shorthand for `pad_tensor` with `rpad=False`"""
    return pad_tensor(
        tensor=tensor,
        padded_length=padded_length,
        pad_value=pad_value,
        rpad=False,
    )
def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`

    shorthand for `pad_tensor` with `rpad=True`; `pad_length` is forwarded as
    the total `padded_length`"""
    return pad_tensor(
        tensor=tensor,
        padded_length=pad_length,
        pad_value=pad_value,
        rpad=True,
    )
def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead

    the padding is created with the dtype of `array`"""
    n_missing: int = padded_length - array.shape[0]
    filler: np.ndarray = np.full(
        (n_missing,),
        pad_value,
        dtype=array.dtype,
    )

    if rpad:
        return np.concatenate([array, filler])
    return np.concatenate([filler, array])
def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`

    shorthand for `pad_array` with `rpad=False`"""
    return pad_array(
        array=array,
        padded_length=padded_length,
        pad_value=pad_value,
        rpad=False,
    )
def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`

    shorthand for `pad_array` with `rpad=True`; `pad_length` is forwarded as
    the total `padded_length`"""
    return pad_array(
        array=array,
        padded_length=pad_length,
        pad_value=pad_value,
        rpad=True,
    )
def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict

    each value's `.shape` is converted to a plain tuple, and the flat mapping
    is then passed through `dotlist_to_nested_dict`"""
    flat_shapes: dict = {}
    for key, tensor in d.items():
        flat_shapes[key] = tuple(tensor.shape)
    return dotlist_to_nested_dict(flat_shapes)
def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes

    returns a JSON string (indent of 2) of the nested shapes"""
    # shapes are stringified first, since indent wont play nice with tuples
    shape_strings: dict = {
        key: str(tuple(tensor.shape)) for key, tensor in d.items()
    }
    nested_shapes = dotlist_to_nested_dict(shape_strings)
    return json.dumps(nested_shapes, indent=2)
class StateDictCompareError(AssertionError):
    """base error for a failed state dict comparison (state dicts don't match)"""
class StateDictKeysError(StateDictCompareError):
    """the two state dicts do not have the same set of keys"""
class StateDictShapeError(StateDictCompareError):
    """the state dicts have entries whose shapes differ"""
class StateDictValueError(StateDictCompareError):
    """the state dicts have entries whose values differ"""
def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    checks run in order: keys, then shapes, then values (via `torch.allclose`).
    returns `None` when everything matches.

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
        relative tolerance passed to `torch.allclose`
        (defaults to `1e-5`)
    - `atol : float`
        absolute tolerance passed to `torch.allclose`
        (defaults to `1e-8`)
    - `verbose : bool`
        whether to include the shapes of the offending entries in the error message
        (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    if d1_keys != d2_keys:
        # sorted lists (not sets) so the ordering survives into the error message
        symmetric_diff: list = sorted(d1_keys ^ d2_keys)
        keys_diff_1: list = sorted(d1_keys - d2_keys)
        keys_diff_2: list = sorted(d2_keys - d1_keys)
        diff_shapes_1: str = (
            string_dict_shapes({k: d1[k] for k in keys_diff_1})
            if verbose
            else "(verbose = False)"
        )
        diff_shapes_2: str = (
            string_dict_shapes({k: d2[k] for k in keys_diff_2})
            if verbose
            else "(verbose = False)"
        )
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first; values are only compared when shapes agree
        if v1.shape != v2.shape:
            shape_failed.append(k)
        elif not torch.allclose(v1, v2, rtol=rtol, atol=atol):
            vals_failed.append(k)

    # shape mismatches take priority over value mismatches;
    # the shape summaries are only built for the error actually raised
    if shape_failed:
        str_shape_failed: str = (
            string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
        )
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if vals_failed:
        str_vals_failed: str = (
            string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
        )
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )