Coverage for muutils/tensor_utils.py: 99%
89 statements
1"""utilities for working with tensors and arrays.
3notably:
5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
6- `DTYPE_MAP` mapping string representations of types to their type
7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match
10"""

from __future__ import annotations

import json
import typing
from typing import Any

import jaxtyping
import numpy as np
import torch

from muutils.dictmagic import dotlist_to_nested_dict

# pylint: disable=missing-class-docstring

TYPE_TO_JAX_DTYPE: dict[Any, Any] = {
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    bool: jaxtyping.Bool,
    jaxtyping.Bool: jaxtyping.Bool,
    np.bool_: jaxtyping.Bool,
    torch.bool: jaxtyping.Bool,
    # numpy float
    np.float16: jaxtyping.Float,
    np.float32: jaxtyping.Float,
    np.float64: jaxtyping.Float,
    np.half: jaxtyping.Float,
    np.single: jaxtyping.Float,
    np.double: jaxtyping.Float,
    # numpy int
    np.int8: jaxtyping.Int,
    np.int16: jaxtyping.Int,
    np.int32: jaxtyping.Int,
    np.int64: jaxtyping.Int,
    np.longlong: jaxtyping.Int,
    np.short: jaxtyping.Int,
    np.uint8: jaxtyping.Int,
    # torch float
    torch.float: jaxtyping.Float,
    torch.float16: jaxtyping.Float,
    torch.float32: jaxtyping.Float,
    torch.float64: jaxtyping.Float,
    torch.half: jaxtyping.Float,
    torch.double: jaxtyping.Float,
    torch.bfloat16: jaxtyping.Float,
    # torch int
    torch.int: jaxtyping.Int,
    torch.int8: jaxtyping.Int,
    torch.int16: jaxtyping.Int,
    torch.int32: jaxtyping.Int,
    torch.int64: jaxtyping.Int,
    torch.long: jaxtyping.Int,
    torch.short: jaxtyping.Int,
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"

# np.float_ and np.int_ were deprecated in numpy 1.20 and removed in 2.0
# use try/except for backwards compatibility and type checker friendliness
try:
    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
except AttributeError:
    pass  # numpy 2.0+ removed these deprecated aliases
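

# illustrative sketch, added for exposition: `_demo_type_to_jax_dtype` is a
# hypothetical helper, not part of the original module. it shows how the
# mapping collapses concrete dtypes down to the three jaxtyping wrappers,
# which can then be subscripted to build a shape annotation at runtime.
def _demo_type_to_jax_dtype() -> None:
    # every float-like dtype resolves to jaxtyping.Float, int-like to Int
    assert TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float
    assert TYPE_TO_JAX_DTYPE[torch.int64] is jaxtyping.Int
    # build an annotation equivalent to jaxtyping.Float[torch.Tensor, "batch dim"]
    annotation = TYPE_TO_JAX_DTYPE[np.float32][torch.Tensor, "batch dim"]
    assert annotation is not None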

# TODO: add proper type annotations to this signature
# TODO: maybe get rid of this altogether?
# def jaxtype_factory(
#     name: str,
#     array_type: type,
#     default_jax_dtype: type[jaxtyping.Float | jaxtyping.Int | jaxtyping.Bool] = jaxtyping.Float,
#     legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
# ) -> type:
#     """usage:
#     ```
#     ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
#     x: ATensor["dim1 dim2", np.float32]
#     ```
#     """
#     legacy_mode_ = ErrorMode.from_any(legacy_mode)

#     class _BaseArray:
#         """jaxtyping shorthand
#         (backwards compatible with older versions of muutils.tensor_utils)

#         default_jax_dtype = {default_jax_dtype}
#         array_type = {array_type}
#         """

#         def __new__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn:
#             raise TypeError("Type FArray cannot be instantiated.")

#         def __init_subclass__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn:
#             raise TypeError(f"Cannot subclass {cls.__name__}")

#         @classmethod
#         def param_info(cls, params: typing.Union[str, tuple[Any, ...]]) -> str:
#             """useful for error printing"""
#             return "\n".join(
#                 f"{k} = {v}"
#                 for k, v in {
#                     "cls.__name__": cls.__name__,
#                     "cls.__doc__": cls.__doc__,
#                     "params": params,
#                     "type(params)": type(params),
#                 }.items()
#             )

#         @typing._tp_cache  # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
#         def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type:  # type: ignore[misc]
#             # MyTensor["dim1 dim2"]
#             if isinstance(params, str):
#                 return default_jax_dtype[array_type, params]

#             elif isinstance(params, tuple):
#                 if len(params) != 2:
#                     raise Exception(
#                         f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
#                     )

#                 if isinstance(params[0], str):
#                     # MyTensor["dim1 dim2", int]
#                     return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

#                 elif isinstance(params[0], tuple):
#                     legacy_mode_.process(
#                         f"legacy type annotation was used:\n{cls.param_info(params) = }",
#                         except_cls=Exception,
#                     )
#                     # MyTensor[("dim1", "dim2"), int]
#                     shape_anot: list[str] = list()
#                     for x in params[0]:
#                         if isinstance(x, str):
#                             shape_anot.append(x)
#                         elif isinstance(x, int):
#                             shape_anot.append(str(x))
#                         elif isinstance(x, tuple):
#                             shape_anot.append("".join(str(y) for y in x))
#                         else:
#                             raise Exception(
#                                 f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
#                             )

#                     return TYPE_TO_JAX_DTYPE[params[1]][
#                         array_type, " ".join(shape_anot)
#                     ]
#             else:
#                 raise Exception(
#                     f"unexpected type for params:\n{cls.param_info(params)}"
#                 )

#     _BaseArray.__name__ = name

#     if _BaseArray.__doc__ is None:
#         _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"

#     _BaseArray.__doc__ = _BaseArray.__doc__.format(
#         default_jax_dtype=repr(default_jax_dtype),
#         array_type=repr(array_type),
#     )

#     return _BaseArray

if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
        def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type:
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        @typing._tp_cache  # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
        def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type:
            raise NotImplementedError()

# ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

# NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]


def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
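

# minimal demo, added for exposition: `_demo_numpy_to_torch_dtype` is
# hypothetical, not part of the original module. the conversion above works by
# round-tripping a zero-dim array through `torch.from_numpy` and reading off
# the resulting dtype; torch dtypes pass through unchanged.
def _demo_numpy_to_torch_dtype() -> None:
    assert numpy_to_torch_dtype(np.dtype("float32")) == torch.float32
    assert numpy_to_torch_dtype(np.int64) == torch.int64
    assert numpy_to_torch_dtype(torch.bool) is torch.bool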


DTYPE_LIST: list[Any] = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # simplest
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.longlong,
        np.short,
        # simplest
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"

# np.float_ and np.int_ were deprecated in numpy 1.20 and removed in 2.0
try:
    DTYPE_LIST.extend([np.float_, np.int_])  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
except AttributeError:
    pass  # numpy 2.0+ removed these deprecated aliases

DTYPE_MAP: dict[str, Any] = {
    **{str(x): x for x in DTYPE_LIST},
    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
}
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict[str, torch.dtype] = {
    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
}
"mapping from string representations of types to specifically torch types"

# "bool" needs explicit entries: str(bool) is "<class 'bool'>" and
# np.bool_.__name__ is "bool_", so neither comprehension above produces a plain "bool" key
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
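

# minimal demo, added for exposition: `_demo_dtype_maps` is hypothetical, not
# part of the original module. DTYPE_MAP is keyed both by `str(dtype)` (which
# covers torch dtypes like "torch.float32") and by `__name__` for numpy types
# (giving bare names like "float32").
def _demo_dtype_maps() -> None:
    assert DTYPE_MAP["float32"] is np.float32
    assert DTYPE_MAP["torch.float32"] is torch.float32
    assert DTYPE_MAP["bool"] is np.bool_
    # TORCH_DTYPE_MAP resolves the same keys to torch dtypes
    assert TORCH_DTYPE_MAP["float32"] == torch.float32
    assert TORCH_DTYPE_MAP["bool"] is torch.bool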


TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}
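

# minimal demo, added for exposition: `_demo_optimizer_from_config` is
# hypothetical, not part of the original module. the map is useful when an
# optimizer is named as a string in a config file.
def _demo_optimizer_from_config() -> None:
    model = torch.nn.Linear(4, 2)
    optimizer_cls = TORCH_OPTIMIZERS_MAP["AdamW"]
    optimizer = optimizer_cls(model.parameters(), lr=1e-3)
    assert isinstance(optimizer, torch.optim.AdamW)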


def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)


def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)


def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
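

# minimal demo, added for exposition: `_demo_pad_tensor` is hypothetical, not
# part of the original module. padding preserves dtype and device, and the
# `rpad` flag simply reverses the order of the [padding, tensor] pair.
def _demo_pad_tensor() -> None:
    x = torch.tensor([1.0, 2.0, 3.0])
    assert torch.equal(lpad_tensor(x, 5), torch.tensor([0.0, 0.0, 1.0, 2.0, 3.0]))
    assert torch.equal(rpad_tensor(x, 5), torch.tensor([1.0, 2.0, 3.0, 0.0, 0.0]))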


def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)


def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)


def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
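

# minimal demo, added for exposition: `_demo_pad_array` is hypothetical, not
# part of the original module. it mirrors the tensor demo above for numpy.
def _demo_pad_array() -> None:
    x = np.array([1.0, 2.0, 3.0])
    assert np.array_equal(lpad_array(x, 5), np.array([0.0, 0.0, 1.0, 2.0, 3.0]))
    assert np.array_equal(rpad_array(x, 5), np.array([1.0, 2.0, 3.0, 0.0, 0.0]))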


def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})


def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent won't play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
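

# minimal demo, added for exposition: `_demo_dict_shapes` is hypothetical, not
# part of the original module. dotted keys (as found in torch state dicts) are
# expanded into nested dicts by `dotlist_to_nested_dict`.
def _demo_dict_shapes() -> None:
    sd = {"layer.weight": torch.zeros(2, 3), "layer.bias": torch.zeros(2)}
    assert get_dict_shapes(sd) == {"layer": {"weight": (2, 3), "bias": (2,)}}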


class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass


class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass


class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass


class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass


def compare_state_dicts(
    d1: dict[str, Any],
    d2: dict[str, Any],
    rtol: float = 1e-5,
    atol: float = 1e-8,
    verbose: bool = True,
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
        (defaults to `1e-5`)
    - `atol : float`
        (defaults to `1e-8`)
    - `verbose : bool`
        (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set[str] = set(d1.keys())
    d2_keys: set[str] = set(d2.keys())
    # sort the key differences so error messages are deterministic and easy to read
    symmetric_diff: list[str] = sorted(d1_keys.symmetric_difference(d2_keys))
    keys_diff_1: list[str] = sorted(d1_keys - d2_keys)
    keys_diff_2: list[str] = sorted(d2_keys - d1_keys)
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if len(symmetric_diff) > 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if v1.shape != v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if len(shape_failed) > 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if len(vals_failed) > 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
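

# minimal demo, added for exposition: `_demo_compare_state_dicts` is
# hypothetical, not part of the original module. when keys and shapes match
# but values differ, the most specific error class is raised.
def _demo_compare_state_dicts() -> None:
    d1 = {"w": torch.zeros(3)}
    d2 = {"w": torch.ones(3)}
    try:
        compare_state_dicts(d1, d2)
    except StateDictValueError:
        pass  # expected: keys and shapes match, values don't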