Coverage for tests / unit / test_tensor_utils_torch.py: 100%
41 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
1from __future__ import annotations
3import numpy as np
4import pytest
5import torch
7from muutils.tensor_utils import (
8 DTYPE_MAP,
9 TORCH_DTYPE_MAP,
10 StateDictKeysError,
11 StateDictShapeError,
12 compare_state_dicts,
13 get_dict_shapes,
14 # jaxtype_factory,
15 lpad_tensor,
16 numpy_to_torch_dtype,
17 pad_tensor,
18 rpad_tensor,
19 lpad_array,
20 pad_array,
21 rpad_array,
22)
def test_pad_array():
    """`pad_array` defaults to left-padding; lpad/rpad zero-pad the named side."""
    original = np.array([1, 2, 3])
    left_padded = np.array([0, 0, 1, 2, 3])
    right_padded = np.array([1, 2, 3, 0, 0])
    assert np.array_equal(pad_array(original, 5), left_padded)
    assert np.array_equal(lpad_array(original, 5), left_padded)
    assert np.array_equal(rpad_array(original, 5), right_padded)
32# def test_jaxtype_factory():
33# ATensor = jaxtype_factory(
34# "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore"
35# )
36# assert ATensor.__name__ == "ATensor"
37# assert "default_jax_dtype = <class 'jaxtyping.Float'" in ATensor.__doc__ # type: ignore[operator]
38# assert "array_type = <class 'torch.Tensor'>" in ATensor.__doc__ # type: ignore[operator]
40# x = ATensor[(1, 2, 3), np.float32] # type: ignore[index]
41# print(x)
42# y = ATensor["dim1 dim2", np.float32] # type: ignore[index]
43# print(y)
def test_numpy_to_torch_dtype():
    """Numpy dtypes convert to the matching torch dtype; torch dtypes pass through."""
    # TODO: type ignores here should not be necessary?
    cases = [
        (np.float32, torch.float32),
        (np.int32, torch.int32),
        (torch.float32, torch.float32),
    ]
    for given, expected in cases:
        assert numpy_to_torch_dtype(given) == expected  # type: ignore[arg-type]
def test_dtype_maps():
    """DTYPE_MAP and TORCH_DTYPE_MAP share keys and agree under dtype conversion."""
    assert len(DTYPE_MAP) == len(TORCH_DTYPE_MAP)
    for name, np_dtype in DTYPE_MAP.items():
        assert name in TORCH_DTYPE_MAP
        assert numpy_to_torch_dtype(np_dtype) == TORCH_DTYPE_MAP[name]
def test_pad_tensor():
    """`pad_tensor` defaults to left-padding; lpad/rpad zero-pad the named side."""
    original = torch.tensor([1, 2, 3])
    left_padded = torch.tensor([0, 0, 1, 2, 3])
    right_padded = torch.tensor([1, 2, 3, 0, 0])
    checks = (
        (pad_tensor, left_padded),
        (lpad_tensor, left_padded),
        (rpad_tensor, right_padded),
    )
    for pad_fn, expected in checks:
        assert torch.all(pad_fn(original, 5) == expected)
def test_compare_state_dicts():
    """compare_state_dicts: equal dicts pass silently; value, shape, and key
    mismatches raise AssertionError, StateDictShapeError, and StateDictKeysError
    respectively."""
    left = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}
    right = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}
    # identical contents: must not raise
    compare_state_dicts(left, right)

    # same shape, different values
    right["a"] = torch.tensor([7, 8, 9])
    with pytest.raises(AssertionError):
        compare_state_dicts(left, right)

    # same key, different shape
    right["a"] = torch.tensor([7, 8, 9, 10])
    with pytest.raises(StateDictShapeError):
        compare_state_dicts(left, right)

    # extra key on one side
    right["c"] = torch.tensor([10, 11, 12])
    with pytest.raises(StateDictKeysError):
        compare_state_dicts(left, right)
def test_get_dict_shapes():
    """get_dict_shapes maps each key to its tensor's shape as a plain tuple."""
    tensors = {
        "a": torch.rand(2, 3),
        "b": torch.rand(1, 3, 5),
        "c": torch.rand(2),
    }
    expected = {"a": (2, 3), "b": (1, 3, 5), "c": (2,)}
    assert get_dict_shapes(tensors) == expected