Coverage for tests / unit / json_serialize / test_util.py: 97%

216 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1from collections import namedtuple 

2from dataclasses import dataclass, field 

3from typing import NamedTuple 

4 

5import pytest 

6 

7# pyright: reportPrivateUsage=false 

8 

# Functions under test, imported from muutils.json_serialize

10from muutils.json_serialize.types import _FORMAT_KEY 

11from muutils.json_serialize.util import ( 

12 UniversalContainer, 

13 _recursive_hashify, 

14 array_safe_eq, 

15 dc_eq, 

16 isinstance_namedtuple, 

17 safe_getsource, 

18 string_as_lines, 

19 try_catch, 

20) 

21 

22 

def test_universal_container():
    """Every value, whatever its type, should report as contained in a UniversalContainer."""
    container = UniversalContainer()
    for candidate in ("anything", 123, None):
        assert candidate in container

28 

29 

def test_isinstance_namedtuple():
    """isinstance_namedtuple accepts both namedtuple flavours and rejects plain tuples."""
    # factory-style namedtuple
    Point = namedtuple("Point", ["x", "y"])
    assert isinstance_namedtuple(Point(1, 2))

    # a plain tuple with the same contents is not a namedtuple
    assert not isinstance_namedtuple((1, 2))

    # class-style typing.NamedTuple
    class Point2(NamedTuple):
        x: int
        y: int

    assert isinstance_namedtuple(Point2(1, 2))

43 

def test_try_catch():
    """try_catch passes return values through and turns exceptions into strings."""

    @try_catch
    def normal_func(x):
        return x

    @try_catch
    def raises_value_error():
        raise ValueError("test error")

    # success path: the wrapped function's return value is untouched
    assert normal_func(10) == 10
    # failure path: the exception is rendered as "<ExceptionName>: <message>"
    assert raises_value_error() == "ValueError: test error"

55 

56 

def test_recursive_hashify():
    """_recursive_hashify turns containers into nested tuples and passes scalars through."""
    cases = [
        ({"a": [1, 2, 3]}, (("a", (1, 2, 3)),)),
        ([1, 2, 3], (1, 2, 3)),
        (123, 123),
    ]
    for value, expected in cases:
        assert _recursive_hashify(value) == expected

    # with force=False, an object that cannot be hashified is an error
    with pytest.raises(ValueError):
        _recursive_hashify(object(), force=False)

63 

64 

def test_string_as_lines():
    """string_as_lines splits text on newlines and maps None to an empty list."""
    # None means "no text" rather than an error
    assert string_as_lines(None) == []
    lines = string_as_lines("line1\nline2\nline3")
    assert lines == ["line1", "line2", "line3"]

68 

69 

def test_safe_getsource():
    """safe_getsource returns source lines, including for try_catch-wrapped functions."""

    def sample_func():
        pass

    source = safe_getsource(sample_func)
    # the first returned line is the function's own ``def`` line; the source is
    # attached as the assertion message instead of being printed unconditionally
    assert "def sample_func():" in source[0], source

    def raises_error():
        raise Exception("test error")

    wrapped_func = try_catch(raises_error)
    error_source = safe_getsource(wrapped_func)
    # the original function's source is expected here, since the decorator
    # is not supposed to hide it
    assert any("def raises_error():" in line for line in error_source), error_source

86 

87 

88# Additional tests from TODO.md 

89 

90 

def test_try_catch_exception_handling():
    """Test that try_catch catches exceptions and returns a "<Type>: <message>" string."""

    @try_catch
    def raises_runtime_error():
        raise RuntimeError("runtime error message")

    @try_catch
    def raises_key_error():
        raise KeyError("missing key")

    @try_catch
    def raises_zero_division():
        return 1 / 0

    # Test that exceptions are caught and serialized
    assert raises_runtime_error() == "RuntimeError: runtime error message"
    # str(KeyError) quotes its key, hence the nested quotes in the expectation
    assert raises_key_error() == "KeyError: 'missing key'"
    result = raises_zero_division()
    # narrow the type explicitly rather than suppressing the type checker;
    # only the prefix is pinned so the check does not depend on the
    # interpreter's wording of the division-by-zero message
    assert isinstance(result, str)
    assert result.startswith("ZeroDivisionError: ")

    # Test with arguments: normal returns pass through, raises are serialized
    @try_catch
    def func_with_args(a, b):
        if a == 0:
            raise ValueError(f"a cannot be 0, got {a}")
        return a + b

    assert func_with_args(1, 2) == 3
    assert func_with_args(0, 2) == "ValueError: a cannot be 0, got 0"

121 

122 

def test_array_safe_eq():
    """Test array_safe_eq with scalars, sequences, dicts, and (optionally) numpy/torch arrays."""
    # Basic types
    assert array_safe_eq(1, 1) is True
    assert array_safe_eq(1, 2) is False
    # Note: strings are treated as sequences by array_safe_eq, so they are
    # deliberately left out of these scalar checks
    assert array_safe_eq(1.5, 1.5) is True
    assert array_safe_eq(True, True) is True

    # Lists and sequences
    assert array_safe_eq([1, 2, 3], [1, 2, 3]) is True
    assert array_safe_eq([1, 2, 3], [1, 2, 4]) is False
    assert array_safe_eq([], []) is True
    assert array_safe_eq((1, 2, 3), (1, 2, 3)) is True

    # Nested arrays
    assert array_safe_eq([[1, 2], [3, 4]], [[1, 2], [3, 4]]) is True
    assert array_safe_eq([[1, 2], [3, 4]], [[1, 2], [3, 5]]) is False
    assert array_safe_eq([[[1]], [[2]]], [[[1]], [[2]]]) is True

    # Dicts
    assert array_safe_eq({"a": 1, "b": 2}, {"a": 1, "b": 2}) is True
    assert array_safe_eq({"a": 1, "b": 2}, {"a": 1, "b": 3}) is False
    assert array_safe_eq({}, {}) is True

    # Mixed nested structures
    assert (
        array_safe_eq({"a": [1, 2], "b": {"c": 3}}, {"a": [1, 2], "b": {"c": 3}})
        is True
    )
    assert (
        array_safe_eq({"a": [1, 2], "b": {"c": 3}}, {"a": [1, 2], "b": {"c": 4}})
        is False
    )

    # Identity check
    obj = {"a": 1}
    assert array_safe_eq(obj, obj) is True

    # Type mismatch
    assert array_safe_eq(1, 1.0) is False  # Different types
    assert array_safe_eq([1, 2], (1, 2)) is False

    # numpy (optional). Comparisons may yield np.True_/np.False_ rather than
    # Python bools, so truthiness is checked instead of identity with True.
    # Only the import is guarded, so an ImportError raised by array_safe_eq
    # itself is not mistaken for "numpy is unavailable".
    try:
        import numpy as np
    except ImportError:
        pass  # numpy not installed: skip these checks
    else:
        arr1 = np.array([1, 2, 3])
        arr2 = np.array([1, 2, 3])
        arr3 = np.array([1, 2, 4])
        assert array_safe_eq(arr1, arr2)
        assert not array_safe_eq(arr1, arr3)

    # torch (optional); same truthiness caveat and narrow import guard as numpy
    try:
        import torch
    except ImportError:
        pass  # torch not installed: skip these checks
    else:
        t1 = torch.tensor([1.0, 2.0, 3.0])
        t2 = torch.tensor([1.0, 2.0, 3.0])
        t3 = torch.tensor([1.0, 2.0, 4.0])
        assert array_safe_eq(t1, t2)
        assert not array_safe_eq(t1, t3)

189 

190 

def test_dc_eq():
    """Exercise dc_eq: same-class comparisons, nested fields, and the class/field mismatch switches."""

    # ---- fixtures ---------------------------------------------------------
    @dataclass
    class Point:
        x: int
        y: int

    @dataclass
    class Point2D:  # same field names as Point, but a different class
        x: int
        y: int

    @dataclass
    class Point3D:  # one extra field compared to Point
        x: int
        y: int
        z: int

    @dataclass
    class PointWithArray:
        x: int
        coords: list

    @dataclass
    class Container:
        items: list
        metadata: dict

    @dataclass
    class Empty:
        pass

    @dataclass
    class AlsoEmpty:
        pass

    @dataclass
    class WithIgnored:
        x: int
        ignored: int = field(compare=False)

    class NotADataclass:
        def __init__(self, x: int):
            self.x = x

    origin = Point(1, 2)
    cube = Point3D(1, 2, 3)
    class_mismatch = "Cannot compare dataclasses of different classes"

    # ---- same class -------------------------------------------------------
    assert dc_eq(origin, Point(1, 2)) is True
    assert dc_eq(origin, Point(1, 3)) is False
    assert dc_eq(origin, origin) is True  # identity

    # list- and dict-valued fields are compared by content
    arr_a = PointWithArray(1, [1, 2, 3])
    assert dc_eq(arr_a, PointWithArray(1, [1, 2, 3])) is True
    assert dc_eq(arr_a, PointWithArray(1, [1, 2, 4])) is False

    box = Container([1, 2, 3], {"name": "test"})
    assert dc_eq(box, Container([1, 2, 3], {"name": "test"})) is True
    assert dc_eq(box, Container([1, 2, 3], {"name": "other"})) is False

    # fields declared with compare=False take no part in the comparison
    assert dc_eq(WithIgnored(1, 100), WithIgnored(1, 999)) is True

    # empty dataclasses
    lone = Empty()
    assert dc_eq(lone, Empty()) is True
    # two different empty classes share the same (empty) field set, so they
    # are equal once cross-class comparison is allowed
    assert dc_eq(lone, AlsoEmpty(), false_when_class_mismatch=False) is True

    # ---- different classes ------------------------------------------------
    # default behaviour (false_when_class_mismatch=True): simply unequal
    assert dc_eq(origin, cube) is False

    # except_when_class_mismatch=True raises, and takes precedence over both
    # false_when_class_mismatch and except_when_field_mismatch
    for extra in (
        {},
        {"false_when_class_mismatch": True},
        {"except_when_field_mismatch": True},
    ):
        with pytest.raises(TypeError, match=class_mismatch):
            dc_eq(origin, cube, except_when_class_mismatch=True, **extra)

    # the field check is only reached with false_when_class_mismatch=False;
    # differing field sets then raise
    with pytest.raises(AttributeError, match="different fields"):
        dc_eq(
            origin, cube, except_when_field_mismatch=True, false_when_class_mismatch=False
        )

    # matching field sets do not raise and are compared by value
    same_fields_equal = dc_eq(
        origin, Point2D(1, 2), except_when_field_mismatch=True, false_when_class_mismatch=False
    )
    assert same_fields_equal is True
    same_fields_differ = dc_eq(
        Point2D(1, 99),
        origin,
        false_when_class_mismatch=False,
        except_when_field_mismatch=True,
    )
    assert same_fields_differ is False

    # ---- non-dataclass arguments are rejected -----------------------------
    with pytest.raises(TypeError):
        dc_eq(NotADataclass(1), NotADataclass(1))

328 

329 

def test_FORMAT_KEY():
    """_FORMAT_KEY is the string "__muutils_format__" and works as a dictionary key."""
    assert isinstance(_FORMAT_KEY, str)
    assert _FORMAT_KEY == "__muutils_format__"

    # typical usage: tagging a serialized dict with its format
    payload = {"value": 42}
    payload[_FORMAT_KEY] = "custom_type"
    assert _FORMAT_KEY in payload
    assert payload[_FORMAT_KEY] == "custom_type"

340 

341 

def test_edge_cases():
    """Test edge cases for utility functions: None values, empty containers, mixed types."""
    # string_as_lines with None
    assert string_as_lines(None) == []
    # Empty string splits to empty list (splitlines behavior)
    assert string_as_lines("") == []
    assert string_as_lines("single") == ["single"]

    # _recursive_hashify with empty containers
    assert _recursive_hashify([]) == ()
    assert _recursive_hashify({}) == ()
    assert _recursive_hashify(()) == ()

    # _recursive_hashify with mixed nested types. Beyond being a tuple, the
    # result must actually be hashable, and each dict entry must become a
    # (key, hashified value) pair. The pairs are compared as a set so the check
    # does not depend on the order in which entries are emitted.
    mixed = {"list": [1, 2], "dict": {"nested": True}, "tuple": (3, 4)}
    result = _recursive_hashify(mixed)
    assert isinstance(result, tuple)
    hash(result)  # raises TypeError if anything unhashable slipped through
    assert set(result) == {
        ("list", (1, 2)),
        ("dict", (("nested", True),)),
        ("tuple", (3, 4)),
    }

    # array_safe_eq with empty containers
    assert array_safe_eq([], []) is True
    assert array_safe_eq({}, {}) is True
    assert array_safe_eq((), ()) is True

    # array_safe_eq with None
    assert array_safe_eq(None, None) is True
    assert array_safe_eq(None, 0) is False

    # try_catch with function returning None: None passes through unchanged
    @try_catch
    def returns_none():
        return None

    assert returns_none() is None

    # UniversalContainer with various types
    uc = UniversalContainer()
    assert None in uc
    assert [] in uc
    assert {} in uc
    assert object() in uc