Coverage for tests/unit/test_tensor_utils.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1from __future__ import annotations 

2 

3import jaxtyping 

4import numpy as np 

5import pytest 

6import torch 

7 

8from muutils.tensor_utils import ( 

9 DTYPE_MAP, 

10 TORCH_DTYPE_MAP, 

11 StateDictKeysError, 

12 StateDictShapeError, 

13 compare_state_dicts, 

14 get_dict_shapes, 

15 jaxtype_factory, 

16 lpad_array, 

17 lpad_tensor, 

18 numpy_to_torch_dtype, 

19 pad_array, 

20 pad_tensor, 

21 rpad_array, 

22 rpad_tensor, 

23) 

24 

25 

def test_jaxtype_factory():
    """Create a tensor type via ``jaxtype_factory`` and check its name, docstring and indexing."""
    ATensor = jaxtype_factory(
        "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore"
    )
    assert ATensor.__name__ == "ATensor"

    # The generated docstring is expected to mention both the jaxtyping dtype
    # class and the array type that were passed to the factory.
    docstring = ATensor.__doc__
    assert "default_jax_dtype = <class 'jaxtyping.Float'" in docstring  # type: ignore[operator]
    assert "array_type = <class 'torch.Tensor'>" in docstring  # type: ignore[operator]

    # Index once with a tuple of ints and once with a named-dimension string;
    # the results are only printed, not asserted on.
    for dims in ((1, 2, 3), "dim1 dim2"):
        print(ATensor[dims, np.float32])  # type: ignore[index]

38 

39 

def test_numpy_to_torch_dtype():
    """Check ``numpy_to_torch_dtype`` on two numpy dtypes and on a torch dtype."""
    # TODO: type ignores here should not be necessary?
    cases = (
        (np.float32, torch.float32),
        (np.int32, torch.int32),
        # a torch dtype given as input is expected to come back unchanged
        (torch.float32, torch.float32),
    )
    for given, expected in cases:
        assert numpy_to_torch_dtype(given) == expected  # type: ignore[arg-type]

45 

46 

def test_dtype_maps():
    """Check that ``DTYPE_MAP`` and ``TORCH_DTYPE_MAP`` agree key-for-key."""
    # Same number of entries in both maps.
    assert len(TORCH_DTYPE_MAP) == len(DTYPE_MAP)

    # Every key of DTYPE_MAP must exist in TORCH_DTYPE_MAP, and converting the
    # DTYPE_MAP value must give the TORCH_DTYPE_MAP value for that key.
    for name, source_dtype in DTYPE_MAP.items():
        assert name in TORCH_DTYPE_MAP
        assert numpy_to_torch_dtype(source_dtype) == TORCH_DTYPE_MAP[name]

52 

53 

def test_pad_tensor():
    """Pad a length-3 tensor to length 5 with each tensor padding helper."""
    source = torch.tensor([1, 2, 3])
    zeros_on_left = torch.tensor([0, 0, 1, 2, 3])
    zeros_on_right = torch.tensor([1, 2, 3, 0, 0])

    # pad_tensor is called without extra arguments and is expected to give the
    # same result as lpad_tensor.
    cases = (
        (pad_tensor, zeros_on_left),
        (lpad_tensor, zeros_on_left),
        (rpad_tensor, zeros_on_right),
    )
    for pad_fn, expected in cases:
        assert torch.all(pad_fn(source, 5) == expected)

59 

60 

def test_pad_array():
    """Pad a length-3 array to length 5 with each array padding helper."""
    source = np.array([1, 2, 3])
    zeros_on_left = np.array([0, 0, 1, 2, 3])
    zeros_on_right = np.array([1, 2, 3, 0, 0])

    # pad_array is called without extra arguments and is expected to give the
    # same result as lpad_array.
    cases = (
        (pad_array, zeros_on_left),
        (lpad_array, zeros_on_left),
        (rpad_array, zeros_on_right),
    )
    for pad_fn, expected in cases:
        assert np.array_equal(pad_fn(source, 5), expected)

66 

67 

def test_compare_state_dicts():
    """Exercise ``compare_state_dicts`` on matching dicts and three kinds of mismatch.

    The second dict is mutated step by step, so each case builds on the
    previous one; the order of the cases matters.
    """

    def make_state():
        # Fresh tensors on every call so the two dicts share no objects.
        return {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}

    reference = make_state()
    candidate = make_state()

    # Identical keys, shapes and values: must not raise.
    compare_state_dicts(reference, candidate)

    # Same keys and shapes, different values under "a".
    candidate["a"] = torch.tensor([7, 8, 9])
    with pytest.raises(AssertionError):
        compare_state_dicts(reference, candidate)

    # Same keys, but "a" now has 4 elements instead of 3.
    candidate["a"] = torch.tensor([7, 8, 9, 10])
    with pytest.raises(StateDictShapeError):
        compare_state_dicts(reference, candidate)

    # Extra key "c"; "a" still carries the mismatched shape from the previous
    # step, and the test expects the key mismatch to be the error raised.
    candidate["c"] = torch.tensor([10, 11, 12])
    with pytest.raises(StateDictKeysError):
        compare_state_dicts(reference, candidate)

84 

85 

def test_get_dict_shapes():
    """Check that ``get_dict_shapes`` reports each tensor's shape as a tuple under its key."""
    expected_shapes = {"a": (2, 3), "b": (1, 3, 5), "c": (2,)}
    # Build one random tensor per entry, with exactly the expected shape.
    tensors = {key: torch.rand(*shape) for key, shape in expected_shapes.items()}
    assert get_dict_shapes(tensors) == expected_shapes