Coverage for tests/unit/test_tensor_utils_torch.py: 100%

51 statements  

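"""Unit tests for muutils.tensor_utils: array/tensor padding helpers, the
numpy/torch dtype maps, jaxtype_factory, and state-dict comparison utilities."""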

from __future__ import annotations

import jaxtyping
import numpy as np
import pytest
import torch

from muutils.tensor_utils import (
    DTYPE_MAP,
    TORCH_DTYPE_MAP,
    StateDictKeysError,
    StateDictShapeError,
    compare_state_dicts,
    get_dict_shapes,
    jaxtype_factory,
    lpad_tensor,
    numpy_to_torch_dtype,
    pad_tensor,
    rpad_tensor,
    lpad_array,
    pad_array,
    rpad_array,
)


def test_pad_array():
    array = np.array([1, 2, 3])
    assert np.array_equal(pad_array(array, 5), np.array([0, 0, 1, 2, 3]))
    assert np.array_equal(lpad_array(array, 5), np.array([0, 0, 1, 2, 3]))
    assert np.array_equal(rpad_array(array, 5), np.array([1, 2, 3, 0, 0]))


def test_jaxtype_factory():
    ATensor = jaxtype_factory(
        "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore"
    )
    assert ATensor.__name__ == "ATensor"
    assert "default_jax_dtype = <class 'jaxtyping.Float'" in ATensor.__doc__  # type: ignore[operator]
    assert "array_type = <class 'torch.Tensor'>" in ATensor.__doc__  # type: ignore[operator]

    x = ATensor[(1, 2, 3), np.float32]  # type: ignore[index]
    print(x)
    y = ATensor["dim1 dim2", np.float32]  # type: ignore[index]
    print(y)


def test_numpy_to_torch_dtype():
    # TODO: type ignores here should not be necessary?
    assert numpy_to_torch_dtype(np.float32) == torch.float32  # type: ignore[arg-type]
    assert numpy_to_torch_dtype(np.int32) == torch.int32  # type: ignore[arg-type]
    assert numpy_to_torch_dtype(torch.float32) == torch.float32


def test_dtype_maps():
    assert len(DTYPE_MAP) == len(TORCH_DTYPE_MAP)
    for key in DTYPE_MAP:
        assert key in TORCH_DTYPE_MAP
        assert numpy_to_torch_dtype(DTYPE_MAP[key]) == TORCH_DTYPE_MAP[key]


def test_pad_tensor():
    tensor = torch.tensor([1, 2, 3])
    assert torch.all(pad_tensor(tensor, 5) == torch.tensor([0, 0, 1, 2, 3]))
    assert torch.all(lpad_tensor(tensor, 5) == torch.tensor([0, 0, 1, 2, 3]))
    assert torch.all(rpad_tensor(tensor, 5) == torch.tensor([1, 2, 3, 0, 0]))


def test_compare_state_dicts():
    d1 = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}
    d2 = {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}
    compare_state_dicts(d1, d2)  # This should not raise an exception

    d2["a"] = torch.tensor([7, 8, 9])
    with pytest.raises(AssertionError):
        compare_state_dicts(d1, d2)  # This should raise an exception

    d2["a"] = torch.tensor([7, 8, 9, 10])
    with pytest.raises(StateDictShapeError):
        compare_state_dicts(d1, d2)  # This should raise an exception

    d2["c"] = torch.tensor([10, 11, 12])
    with pytest.raises(StateDictKeysError):
        compare_state_dicts(d1, d2)  # This should raise an exception


def test_get_dict_shapes():
    x = {"a": torch.rand(2, 3), "b": torch.rand(1, 3, 5), "c": torch.rand(2)}
    x_shapes = get_dict_shapes(x)
    assert x_shapes == {"a": (2, 3), "b": (1, 3, 5), "c": (2,)}