Coverage for tests / unit / test_tensor_utils_torch.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1from __future__ import annotations 

2 

3import numpy as np 

4import pytest 

5import torch 

6 

7from muutils.tensor_utils import ( 

8 DTYPE_MAP, 

9 TORCH_DTYPE_MAP, 

10 StateDictKeysError, 

11 StateDictShapeError, 

12 compare_state_dicts, 

13 get_dict_shapes, 

14 # jaxtype_factory, 

15 lpad_tensor, 

16 numpy_to_torch_dtype, 

17 pad_tensor, 

18 rpad_tensor, 

19 lpad_array, 

20 pad_array, 

21 rpad_array, 

22) 

23 

24 

def test_pad_array():
    """Zero-pad a 1-D numpy array to length 5 with each padding helper.

    ``pad_array`` is expected to behave like ``lpad_array`` (zeros prepended),
    while ``rpad_array`` appends the zeros instead.
    """
    source = np.array([1, 2, 3])
    left_padded = np.array([0, 0, 1, 2, 3])
    right_padded = np.array([1, 2, 3, 0, 0])

    cases = (
        (pad_array, left_padded),
        (lpad_array, left_padded),
        (rpad_array, right_padded),
    )
    for pad_fn, expected in cases:
        assert np.array_equal(pad_fn(source, 5), expected)

30 

31 

32# def test_jaxtype_factory(): 

33# ATensor = jaxtype_factory( 

34# "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore" 

35# ) 

36# assert ATensor.__name__ == "ATensor" 

37# assert "default_jax_dtype = <class 'jaxtyping.Float'" in ATensor.__doc__ # type: ignore[operator] 

38# assert "array_type = <class 'torch.Tensor'>" in ATensor.__doc__ # type: ignore[operator] 

39 

40# x = ATensor[(1, 2, 3), np.float32] # type: ignore[index] 

41# print(x) 

42# y = ATensor["dim1 dim2", np.float32] # type: ignore[index] 

43# print(y) 

44 

45 

def test_numpy_to_torch_dtype():
    """numpy scalar types convert to torch dtypes; a torch dtype passes through."""
    # TODO: type ignores here should not be necessary?
    numpy_to_torch_pairs = (
        (np.float32, torch.float32),
        (np.int32, torch.int32),
    )
    for np_dtype, torch_dtype in numpy_to_torch_pairs:
        assert numpy_to_torch_dtype(np_dtype) == torch_dtype  # type: ignore[arg-type]

    # an input that is already a torch dtype is expected back unchanged
    assert numpy_to_torch_dtype(torch.float32) == torch.float32

51 

52 

def test_dtype_maps():
    """DTYPE_MAP and TORCH_DTYPE_MAP cover the same keys and agree per key.

    For every entry, converting the numpy dtype with ``numpy_to_torch_dtype``
    must give the torch dtype stored under the same key in TORCH_DTYPE_MAP.
    """
    assert len(DTYPE_MAP) == len(TORCH_DTYPE_MAP)
    for name, np_dtype in DTYPE_MAP.items():
        # equal lengths plus this membership check imply identical key sets
        assert name in TORCH_DTYPE_MAP
        assert numpy_to_torch_dtype(np_dtype) == TORCH_DTYPE_MAP[name]

58 

59 

def test_pad_tensor():
    """Zero-pad a 1-D torch tensor to length 5 with each padding helper.

    ``pad_tensor`` is expected to behave like ``lpad_tensor`` (zeros
    prepended), while ``rpad_tensor`` appends the zeros instead.

    ``torch.all(result == expected)`` on its own broadcasts the two operands,
    so a result with the wrong shape but matching values (e.g. shape (1, 5)
    instead of (5,)) would still pass. The shape is therefore checked first.
    """
    tensor = torch.tensor([1, 2, 3])
    left_expected = torch.tensor([0, 0, 1, 2, 3])
    right_expected = torch.tensor([1, 2, 3, 0, 0])

    cases = (
        (pad_tensor, left_expected),
        (lpad_tensor, left_expected),
        (rpad_tensor, right_expected),
    )
    for pad_fn, expected in cases:
        result = pad_fn(tensor, 5)
        assert result.shape == expected.shape
        # elementwise `==` rather than torch.equal, so that (as before) a
        # dtype difference alone is not treated as a failure
        assert torch.all(result == expected)

65 

66 

def test_compare_state_dicts():
    """compare_state_dicts accepts equal dicts and raises on each kind of mismatch.

    Each failing case is built from a fresh copy of ``d1`` with exactly one
    mismatch, so every ``pytest.raises`` exercises a single check. Mutating
    one shared dict would let an earlier mismatch (e.g. a wrong shape) leak
    into the key-mismatch case and hide which check actually fired.
    """
    d1 = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}
    d2 = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}
    compare_state_dicts(d1, d2)  # identical contents: should not raise

    # same keys and shapes, different values for "a"
    with pytest.raises(AssertionError):
        compare_state_dicts(d1, {**d1, "a": torch.tensor([7, 8, 9])})

    # same keys, mismatched shape for "a"
    with pytest.raises(StateDictShapeError):
        compare_state_dicts(d1, {**d1, "a": torch.tensor([7, 8, 9, 10])})

    # extra key "c"; all shared entries are otherwise identical
    with pytest.raises(StateDictKeysError):
        compare_state_dicts(d1, {**d1, "c": torch.tensor([10, 11, 12])})

83 

84 

def test_get_dict_shapes():
    """get_dict_shapes maps each key to the shape of the tensor stored under it."""
    expected_shapes = {"a": (2, 3), "b": (1, 3, 5), "c": (2,)}
    # random tensors with exactly the shapes listed above, same key order
    tensors = {name: torch.rand(*shape) for name, shape in expected_shapes.items()}
    assert get_dict_shapes(tensors) == expected_shapes