Coverage for tests/unit/test_tensor_utils_torch.py: 100%
51 statements
coverage.py v7.11.0, created at 2025-10-28 17:24 +0000
1from __future__ import annotations
3import jaxtyping
4import numpy as np
5import pytest
6import torch
8from muutils.tensor_utils import (
9 DTYPE_MAP,
10 TORCH_DTYPE_MAP,
11 StateDictKeysError,
12 StateDictShapeError,
13 compare_state_dicts,
14 get_dict_shapes,
15 jaxtype_factory,
16 lpad_tensor,
17 numpy_to_torch_dtype,
18 pad_tensor,
19 rpad_tensor,
20 lpad_array,
21 pad_array,
22 rpad_array,
23)
def test_pad_array():
    """``pad_array``/``lpad_array`` zero-pad on the left; ``rpad_array`` on the right."""
    source = np.array([1, 2, 3])
    left_padded = np.array([0, 0, 1, 2, 3])
    right_padded = np.array([1, 2, 3, 0, 0])

    # (padding function, expected result when padding `source` to length 5)
    cases = [
        (pad_array, left_padded),
        (lpad_array, left_padded),
        (rpad_array, right_padded),
    ]
    for pad_fn, expected in cases:
        assert np.array_equal(pad_fn(source, 5), expected)
def test_jaxtype_factory():
    """``jaxtype_factory`` builds a named, documented type that can be subscripted."""
    ATensor = jaxtype_factory(
        "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore"
    )
    assert ATensor.__name__ == "ATensor"

    # Bind and narrow the docstring once. A missing docstring then fails with a
    # clear assertion instead of a TypeError from `"..." in None`, and type
    # checkers no longer need `type: ignore[operator]` on each membership test.
    doc = ATensor.__doc__
    assert doc is not None
    assert "default_jax_dtype = <class 'jaxtyping.Float'" in doc
    assert "array_type = <class 'torch.Tensor'>" in doc

    # Subscripting with an explicit shape tuple, or with a dimension-name string,
    # plus a numpy dtype must both succeed; printing exercises the repr.
    x = ATensor[(1, 2, 3), np.float32]  # type: ignore[index]
    print(x)
    y = ATensor["dim1 dim2", np.float32]  # type: ignore[index]
    print(y)
def test_numpy_to_torch_dtype():
    """numpy scalar types map to the matching torch dtype; torch dtypes pass through."""
    # TODO: type ignores here should not be necessary?
    numpy_to_torch_pairs = [
        (np.float32, torch.float32),
        (np.int32, torch.int32),
    ]
    for numpy_type, torch_type in numpy_to_torch_pairs:
        assert numpy_to_torch_dtype(numpy_type) == torch_type  # type: ignore[arg-type]

    # an argument that is already a torch dtype comes back unchanged
    assert numpy_to_torch_dtype(torch.float32) == torch.float32
def test_dtype_maps():
    """``DTYPE_MAP`` and ``TORCH_DTYPE_MAP`` share keys and agree via ``numpy_to_torch_dtype``."""
    # Key views compare as sets, so this is equivalent to "same length and every
    # DTYPE_MAP key is present in TORCH_DTYPE_MAP".
    assert DTYPE_MAP.keys() == TORCH_DTYPE_MAP.keys()

    for key, np_dtype in DTYPE_MAP.items():
        assert numpy_to_torch_dtype(np_dtype) == TORCH_DTYPE_MAP[key]
def test_pad_tensor():
    """``pad_tensor``/``lpad_tensor`` zero-pad on the left; ``rpad_tensor`` on the right.

    Results are compared with ``torch.testing.assert_close`` using zero
    tolerances, not ``torch.all(result == expected)``. ``==`` broadcasts, so a
    padded tensor with the wrong shape (e.g. an extra leading dimension of size
    1) would still compare all-True and slip through.
    """
    tensor = torch.tensor([1, 2, 3])
    left_padded = torch.tensor([0, 0, 1, 2, 3])
    right_padded = torch.tensor([1, 2, 3, 0, 0])

    # check_dtype=False keeps the original dtype-agnostic comparison (``==``
    # type-promotes too); shape and exact values are now enforced.
    torch.testing.assert_close(
        pad_tensor(tensor, 5), left_padded, rtol=0, atol=0, check_dtype=False
    )
    torch.testing.assert_close(
        lpad_tensor(tensor, 5), left_padded, rtol=0, atol=0, check_dtype=False
    )
    torch.testing.assert_close(
        rpad_tensor(tensor, 5), right_padded, rtol=0, atol=0, check_dtype=False
    )
def test_compare_state_dicts():
    """Matching state dicts pass; each kind of mismatch raises its own error."""
    reference = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}

    # identical keys, shapes and values: must not raise
    identical = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}
    compare_state_dicts(reference, identical)

    # Each failing case is a fresh copy of `reference` that differs in exactly
    # one way, so the expected exception cannot be produced (or masked) by an
    # unrelated mismatch left over from a previous case.
    different_values = {**reference, "a": torch.tensor([7, 8, 9])}
    with pytest.raises(AssertionError):
        compare_state_dicts(reference, different_values)

    different_shape = {**reference, "a": torch.tensor([7, 8, 9, 10])}
    with pytest.raises(StateDictShapeError):
        compare_state_dicts(reference, different_shape)

    extra_key = {**reference, "c": torch.tensor([10, 11, 12])}
    with pytest.raises(StateDictKeysError):
        compare_state_dicts(reference, extra_key)
def test_get_dict_shapes():
    """``get_dict_shapes`` maps every key to the shape of its tensor."""
    expected_shapes = {"a": (2, 3), "b": (1, 3, 5), "c": (2,)}
    # random tensors are drawn in key order a, b, c
    tensors = {key: torch.rand(*shape) for key, shape in expected_shapes.items()}
    assert get_dict_shapes(tensors) == expected_shapes