Coverage for tests/unit/test_tensor_utils.py: 100%
51 statements
coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1from __future__ import annotations
3import jaxtyping
4import numpy as np
5import pytest
6import torch
8from muutils.tensor_utils import (
9 DTYPE_MAP,
10 TORCH_DTYPE_MAP,
11 StateDictKeysError,
12 StateDictShapeError,
13 compare_state_dicts,
14 get_dict_shapes,
15 jaxtype_factory,
16 lpad_array,
17 lpad_tensor,
18 numpy_to_torch_dtype,
19 pad_array,
20 pad_tensor,
21 rpad_array,
22 rpad_tensor,
23)
def test_jaxtype_factory():
    """Check that `jaxtype_factory` builds a named, documented, subscriptable type."""
    ATensor = jaxtype_factory(
        "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore"
    )
    assert ATensor.__name__ == "ATensor"

    # `__doc__` is typed `str | None`; narrow it explicitly so a missing
    # docstring fails with a clear assertion (not a TypeError from `in None`)
    # and the `type: ignore[operator]` comments are no longer needed.
    doc = ATensor.__doc__
    assert doc is not None
    assert "default_jax_dtype = <class 'jaxtyping.Float'" in doc
    assert "array_type = <class 'torch.Tensor'>" in doc

    # subscripting with a shape tuple or a dimension string must not raise
    x = ATensor[(1, 2, 3), np.float32]  # type: ignore[index]
    print(x)
    y = ATensor["dim1 dim2", np.float32]  # type: ignore[index]
    print(y)
def test_numpy_to_torch_dtype():
    """numpy dtypes convert to their torch equivalents; torch dtypes pass through."""
    numpy_cases = (
        (np.float32, torch.float32),
        (np.int32, torch.int32),
    )
    for np_dtype, torch_dtype in numpy_cases:
        # TODO: type ignores here should not be necessary?
        assert numpy_to_torch_dtype(np_dtype) == torch_dtype  # type: ignore[arg-type]

    # a torch dtype given directly comes back unchanged
    assert numpy_to_torch_dtype(torch.float32) == torch.float32
def test_dtype_maps():
    """`DTYPE_MAP` and `TORCH_DTYPE_MAP` share keys and agree after conversion."""
    # identical key sets: same as "equal length and every DTYPE_MAP key present"
    assert DTYPE_MAP.keys() == TORCH_DTYPE_MAP.keys()
    for dtype_key, np_dtype in DTYPE_MAP.items():
        assert numpy_to_torch_dtype(np_dtype) == TORCH_DTYPE_MAP[dtype_key]
def test_pad_tensor():
    """`pad_tensor`/`lpad_tensor` zero-pad on the left, `rpad_tensor` on the right."""
    tensor = torch.tensor([1, 2, 3])
    left_padded = torch.tensor([0, 0, 1, 2, 3])
    right_padded = torch.tensor([1, 2, 3, 0, 0])
    cases = (
        ("pad_tensor", pad_tensor, left_padded),
        ("lpad_tensor", lpad_tensor, left_padded),
        ("rpad_tensor", rpad_tensor, right_padded),
    )
    for label, pad_fn, expected in cases:
        result = pad_fn(tensor, 5)
        # `torch.all(result == expected)` alone broadcasts, so a result of any
        # shape broadcastable to `expected` (e.g. (1, 5)) would pass; pin the
        # shape as `np.array_equal` does for the array variants.
        assert result.shape == expected.shape, (label, result.shape)
        assert torch.all(result == expected), (label, result)
def test_pad_array():
    """`pad_array`/`lpad_array` zero-pad on the left, `rpad_array` on the right."""
    source = np.array([1, 2, 3])
    want_left = np.array([0, 0, 1, 2, 3])
    want_right = np.array([1, 2, 3, 0, 0])
    for pad_fn, want in (
        (pad_array, want_left),
        (lpad_array, want_left),
        (rpad_array, want_right),
    ):
        assert np.array_equal(pad_fn(source, 5), want)
def test_compare_state_dicts():
    """`compare_state_dicts` accepts equal dicts and raises on each kind of mismatch.

    Every failing case is built from a clean copy of ``d2`` that differs from
    ``d1`` in exactly one way, so each ``pytest.raises`` pins the error for a
    single failure mode rather than whichever check happens to run first.
    """
    d1 = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}
    d2 = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}
    compare_state_dicts(d1, d2)  # identical contents: must not raise

    # same keys and shapes, different values
    with pytest.raises(AssertionError):
        compare_state_dicts(d1, {**d2, "a": torch.tensor([7, 8, 9])})

    # same keys, mismatched shape for "a"
    with pytest.raises(StateDictShapeError):
        compare_state_dicts(d1, {**d2, "a": torch.tensor([7, 8, 9, 10])})

    # extra key "c"; all shared entries match d1, so only the key check can fail
    with pytest.raises(StateDictKeysError):
        compare_state_dicts(d1, {**d2, "c": torch.tensor([10, 11, 12])})
def test_get_dict_shapes():
    """`get_dict_shapes` maps each key to its tensor's shape as a tuple."""
    expected = {"a": (2, 3), "b": (1, 3, 5), "c": (2,)}
    # build random tensors of the expected shapes, in key order a, b, c
    tensors = {key: torch.rand(*shape) for key, shape in expected.items()}
    assert get_dict_shapes(tensors) == expected