Coverage for tests/unit/json_serialize/serializable_dataclass/test_helpers.py: 100%

102 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4 

5import numpy as np 

6import torch 

7 

8from muutils.json_serialize.serializable_dataclass import array_safe_eq, dc_eq 

9 

10 

11def test_array_safe_eq(): 

12 assert array_safe_eq(np.array([1, 2, 3]), np.array([1, 2, 3])) 

13 assert not array_safe_eq(np.array([1, 2, 3]), np.array([4, 5, 6])) 

14 assert array_safe_eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])) 

15 assert not array_safe_eq(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])) 

16 assert array_safe_eq(np.array([]), np.array([])) 

17 assert array_safe_eq(np.array([[]]), np.array([[]])) 

18 assert array_safe_eq([], []) 

19 assert array_safe_eq(dict(), dict()) 

20 assert array_safe_eq([1, 2, 3], [1, 2, 3]) 

21 assert array_safe_eq([np.array([1, 2, 3])], [np.array([1, 2, 3])]) 

22 assert not array_safe_eq([], [np.array([1, 2, 3])]) 

23 assert not array_safe_eq([[np.array([1, 2, 3])]], [np.array([1, 2, 3])]) 

24 assert array_safe_eq( 

25 [np.array([1, 2, 3]), torch.tensor([1, 2, 3])], 

26 [np.array([1, 2, 3]), torch.tensor([1, 2, 3])], 

27 ) 

28 assert array_safe_eq([np.array([1, 2, 3]), []], [np.array([1, 2, 3]), []]) 

29 assert not array_safe_eq([[], np.array([1, 2, 3])], [np.array([1, 2, 3]), []]) 

30 

31 

32def test_dc_eq_case1(): 

33 @dataclass(eq=False) 

34 class TestClass: 

35 a: int 

36 b: np.ndarray 

37 c: torch.Tensor 

38 e: list[int] 

39 f: dict[str, int] 

40 

41 instance1 = TestClass( 

42 a=1, 

43 b=np.array([1, 2, 3]), 

44 c=torch.tensor([1, 2, 3]), 

45 e=[1, 2, 3], 

46 f={"key1": 1, "key2": 2}, 

47 ) 

48 

49 instance2 = TestClass( 

50 a=1, 

51 b=np.array([1, 2, 3]), 

52 c=torch.tensor([1, 2, 3]), 

53 e=[1, 2, 3], 

54 f={"key1": 1, "key2": 2}, 

55 ) 

56 

57 assert dc_eq(instance1, instance2) 

58 

59 

60def test_dc_eq_case2(): 

61 @dataclass(eq=False) 

62 class TestClass: 

63 a: int 

64 b: np.ndarray 

65 c: torch.Tensor 

66 e: list[int] 

67 f: dict[str, int] 

68 

69 instance1 = TestClass( 

70 a=1, 

71 b=np.array([1, 2, 3]), 

72 c=torch.tensor([1, 2, 3]), 

73 e=[1, 2, 3], 

74 f={"key1": 1, "key2": 2}, 

75 ) 

76 

77 instance2 = TestClass( 

78 a=1, 

79 b=np.array([4, 5, 6]), 

80 c=torch.tensor([1, 2, 3]), 

81 e=[1, 2, 3], 

82 f={"key1": 1, "key2": 2}, 

83 ) 

84 

85 assert not dc_eq(instance1, instance2) 

86 

87 

88def test_dc_eq_case3(): 

89 @dataclass(eq=False) 

90 class TestClass: 

91 a: int 

92 b: np.ndarray 

93 c: torch.Tensor 

94 e: list[int] 

95 f: dict[str, int] 

96 

97 instance1 = TestClass( 

98 a=1, 

99 b=np.array([1, 2, 3]), 

100 c=torch.tensor([1, 2, 3]), 

101 e=[1, 2, 3], 

102 f={"key1": 1, "key2": 2}, 

103 ) 

104 

105 instance2 = TestClass( 

106 a=2, 

107 b=np.array([1, 2, 3]), 

108 c=torch.tensor([1, 2, 3]), 

109 e=[1, 2, 3], 

110 f={"key1": 1, "key2": 2}, 

111 ) 

112 

113 assert not dc_eq(instance1, instance2) 

114 

115 

116def test_dc_eq_case4(): 

117 @dataclass(eq=False) 

118 class TestClass: 

119 a: int 

120 b: np.ndarray 

121 c: torch.Tensor 

122 e: list[int] 

123 f: dict[str, int] 

124 

125 @dataclass(eq=False) 

126 class TestClass2: 

127 a: int 

128 b: np.ndarray 

129 c: torch.Tensor 

130 e: list[int] 

131 f: dict[str, int] 

132 

133 instance1 = TestClass( 

134 a=1, 

135 b=np.array([1, 2, 3]), 

136 c=torch.tensor([1, 2, 3]), 

137 e=[1, 2, 3], 

138 f={"key1": 1, "key2": 2}, 

139 ) 

140 

141 instance2 = TestClass2( 

142 a=1, 

143 b=np.array([1, 2, 3]), 

144 c=torch.tensor([1, 2, 3]), 

145 e=[1, 2, 3], 

146 f={"key1": 1, "key2": 2}, 

147 ) 

148 

149 assert not dc_eq(instance1, instance2) 

150 

151 

152def test_dc_eq_case5(): 

153 @dataclass(eq=False) 

154 class TestClass: 

155 a: int 

156 

157 @dataclass(eq=False) 

158 class TestClass2: 

159 a: int 

160 

161 instance1 = TestClass(a=1) 

162 

163 instance2 = TestClass2(a=1) 

164 

165 assert not dc_eq(instance1, instance2) 

166 

167 

168def test_dc_eq_case6(): 

169 @dataclass(eq=False) 

170 class TestClass: 

171 pass 

172 

173 @dataclass(eq=False) 

174 class TestClass2: 

175 pass 

176 

177 instance1 = TestClass() 

178 

179 instance2 = TestClass2() 

180 

181 assert not dc_eq(instance1, instance2) 

182 

183 

184def test_dc_eq_case7(): 

185 @dataclass 

186 class TestClass: 

187 pass 

188 

189 @dataclass 

190 class TestClass2: 

191 pass 

192 

193 instance1 = TestClass() 

194 

195 instance2 = TestClass2() 

196 

197 assert not dc_eq(instance1, instance2)