Coverage for tests / unit / json_serialize / test_util.py: 97%
216 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
1from collections import namedtuple
2from dataclasses import dataclass, field
3from typing import NamedTuple
5import pytest
7# pyright: reportPrivateUsage=false
# Code under test is imported from muutils.json_serialize
10from muutils.json_serialize.types import _FORMAT_KEY
11from muutils.json_serialize.util import (
12 UniversalContainer,
13 _recursive_hashify,
14 array_safe_eq,
15 dc_eq,
16 isinstance_namedtuple,
17 safe_getsource,
18 string_as_lines,
19 try_catch,
20)
def test_universal_container():
    """UniversalContainer reports membership for values of any type."""
    container = UniversalContainer()
    for candidate in ("anything", 123, None):
        assert candidate in container
def test_isinstance_namedtuple():
    """Both namedtuple flavours are detected; a bare tuple is not."""
    # functional API: collections.namedtuple
    Point = namedtuple("Point", ["x", "y"])
    assert isinstance_namedtuple(Point(1, 2))
    # a plain tuple with the same contents is not a namedtuple
    assert not isinstance_namedtuple((1, 2))

    # class-based API: typing.NamedTuple
    class Point2(NamedTuple):
        x: int
        y: int

    assert isinstance_namedtuple(Point2(1, 2))
def test_try_catch():
    """try_catch turns a raised exception into a string and passes normal results through."""

    def raises_value_error():
        raise ValueError("test error")

    def normal_func(x):
        return x

    # apply the decorator explicitly rather than with @-syntax
    guarded_failure = try_catch(raises_value_error)
    guarded_passthrough = try_catch(normal_func)

    assert guarded_failure() == "ValueError: test error"
    assert guarded_passthrough(10) == 10
def test_recursive_hashify():
    """Containers become nested tuples; scalars come back unchanged."""
    cases = [
        ({"a": [1, 2, 3]}, (("a", (1, 2, 3)),)),
        ([1, 2, 3], (1, 2, 3)),
        (123, 123),
    ]
    for given, expected in cases:
        assert _recursive_hashify(given) == expected

    # without force, an object of an unsupported type raises ValueError
    with pytest.raises(ValueError):
        _recursive_hashify(object(), force=False)
def test_string_as_lines():
    """Newline-separated text is split into lines; None yields an empty list."""
    expected_lines = ["line1", "line2", "line3"]
    assert string_as_lines("\n".join(expected_lines)) == expected_lines
    assert string_as_lines(None) == []
def test_safe_getsource():
    """safe_getsource returns source lines, including for try_catch-wrapped functions."""

    def sample_func():
        pass

    plain_lines = safe_getsource(sample_func)
    print(f"Source of sample_func: {plain_lines}")
    # the first returned line is the function's own `def` line
    assert "def sample_func():" in plain_lines[0]

    def raises_error():
        raise Exception("test error")

    wrapped_lines = safe_getsource(try_catch(raises_error))
    print(f"Source of wrapped_func: {wrapped_lines}")
    # the wrapped function's own source is expected, not the decorator's
    assert any("def raises_error():" in line for line in wrapped_lines)
88# Additional tests from TODO.md
def test_try_catch_exception_handling():
    """Test that try_catch properly catches exceptions and returns default error message.

    As the assertions below show, a wrapped call that raises returns the string
    ``"<ExceptionClassName>: <str(exception)>"`` instead of propagating, and a
    call that does not raise returns its normal result.
    """

    @try_catch
    def raises_runtime_error():
        raise RuntimeError("runtime error message")

    @try_catch
    def raises_key_error():
        raise KeyError("missing key")

    @try_catch
    def raises_zero_division():
        return 1 / 0

    # Test that exceptions are caught and serialized
    assert raises_runtime_error() == "RuntimeError: runtime error message"
    # str(KeyError(...)) wraps the message in quotes
    assert raises_key_error() == "KeyError: 'missing key'"

    # Pin the full message like the cases above instead of a substring check.
    # The exact ZeroDivisionError wording is an interpreter detail, so take it
    # from the interpreter itself; this also avoids a type-checker suppression.
    with pytest.raises(ZeroDivisionError) as zero_div_info:
        _ = 1 / 0
    assert raises_zero_division() == f"ZeroDivisionError: {zero_div_info.value}"

    # Test with arguments: the wrapper forwards them, both on success and failure
    @try_catch
    def func_with_args(a, b):
        if a == 0:
            raise ValueError(f"a cannot be 0, got {a}")
        return a + b

    assert func_with_args(1, 2) == 3
    assert func_with_args(0, 2) == "ValueError: a cannot be 0, got 0"
def test_array_safe_eq():
    """Test array_safe_eq with numpy arrays, torch tensors, and nested arrays.

    numpy and torch checks run only when those packages are importable.
    """
    # Basic types
    assert array_safe_eq(1, 1) is True
    assert array_safe_eq(1, 2) is False
    # Note: strings are treated as sequences by array_safe_eq, so they are
    # intentionally left out of these scalar checks
    assert array_safe_eq(1.5, 1.5) is True
    assert array_safe_eq(True, True) is True

    # Lists and sequences
    assert array_safe_eq([1, 2, 3], [1, 2, 3]) is True
    assert array_safe_eq([1, 2, 3], [1, 2, 4]) is False
    assert array_safe_eq([], []) is True
    assert array_safe_eq((1, 2, 3), (1, 2, 3)) is True

    # Nested arrays
    assert array_safe_eq([[1, 2], [3, 4]], [[1, 2], [3, 4]]) is True
    assert array_safe_eq([[1, 2], [3, 4]], [[1, 2], [3, 5]]) is False
    assert array_safe_eq([[[1]], [[2]]], [[[1]], [[2]]]) is True

    # Dicts
    assert array_safe_eq({"a": 1, "b": 2}, {"a": 1, "b": 2}) is True
    assert array_safe_eq({"a": 1, "b": 2}, {"a": 1, "b": 3}) is False
    assert array_safe_eq({}, {}) is True

    # Mixed nested structures
    assert (
        array_safe_eq({"a": [1, 2], "b": {"c": 3}}, {"a": [1, 2], "b": {"c": 3}})
        is True
    )
    assert (
        array_safe_eq({"a": [1, 2], "b": {"c": 3}}, {"a": [1, 2], "b": {"c": 4}})
        is False
    )

    # Identity check
    obj = {"a": 1}
    assert array_safe_eq(obj, obj) is True

    # Type mismatch
    assert array_safe_eq(1, 1.0) is False  # Different types
    assert array_safe_eq([1, 2], (1, 2)) is False

    # Optional numpy checks. Only the import sits inside `try`, so an
    # ImportError raised by array_safe_eq itself is not silently swallowed.
    # The result may be np.True_/np.False_ rather than a Python bool, so
    # test truthiness instead of `is True`.
    try:
        import numpy as np
    except ImportError:
        pass  # numpy not installed: skip these checks
    else:
        arr1 = np.array([1, 2, 3])
        arr2 = np.array([1, 2, 3])
        arr3 = np.array([1, 2, 4])
        assert array_safe_eq(arr1, arr2)
        assert not array_safe_eq(arr1, arr3)

    # Optional torch checks, same pattern (the result may be a tensor bool).
    try:
        import torch
    except ImportError:
        pass  # torch not installed: skip these checks
    else:
        t1 = torch.tensor([1.0, 2.0, 3.0])
        t2 = torch.tensor([1.0, 2.0, 3.0])
        t3 = torch.tensor([1.0, 2.0, 4.0])
        assert array_safe_eq(t1, t2)
        assert not array_safe_eq(t1, t3)
def test_dc_eq():
    """Exercise dc_eq on dataclasses.

    Covers plain equality, class-mismatch handling, field-mismatch handling,
    flag precedence, empty dataclasses, ``compare=False`` fields, and
    rejection of non-dataclass inputs.
    """

    # --- fixtures: all local dataclasses used below ---
    @dataclass
    class Point:
        x: int
        y: int

    @dataclass
    class Point3D:
        x: int
        y: int
        z: int

    @dataclass
    class PointWithArray:
        x: int
        coords: list

    @dataclass
    class Container:
        items: list
        metadata: dict

    @dataclass
    class Point2D:  # same field names as Point, but a distinct class
        x: int
        y: int

    @dataclass
    class Empty:
        pass

    @dataclass
    class AlsoEmpty:
        pass

    @dataclass
    class WithIgnored:
        x: int
        ignored: int = field(compare=False)

    class NotADataclass:
        def __init__(self, x: int):
            self.x = x

    base = Point(1, 2)
    spatial = Point3D(1, 2, 3)
    class_mismatch_msg = "Cannot compare dataclasses of different classes"

    # same class: equal values, differing values, and identity
    assert dc_eq(base, Point(1, 2)) is True
    assert dc_eq(base, Point(1, 3)) is False
    assert dc_eq(base, base) is True

    # different classes: False by default (false_when_class_mismatch=True) ...
    assert dc_eq(base, spatial) is False
    # ... or a TypeError when except_when_class_mismatch=True
    with pytest.raises(TypeError, match=class_mismatch_msg):
        dc_eq(base, spatial, except_when_class_mismatch=True)

    # list-valued fields
    with_list = PointWithArray(1, [1, 2, 3])
    assert dc_eq(with_list, PointWithArray(1, [1, 2, 3])) is True
    assert dc_eq(with_list, PointWithArray(1, [1, 2, 4])) is False

    # nested list/dict fields
    box = Container([1, 2, 3], {"name": "test"})
    assert dc_eq(box, Container([1, 2, 3], {"name": "test"})) is True
    assert dc_eq(box, Container([1, 2, 3], {"name": "other"})) is False

    # The field check is only reached once false_when_class_mismatch=False;
    # differing field sets then raise AttributeError.
    with pytest.raises(AttributeError, match="different fields"):
        dc_eq(
            base, spatial, except_when_field_mismatch=True, false_when_class_mismatch=False
        )

    # Different classes sharing the same field names do not raise;
    # the field values decide the result.
    cross_class = dict(except_when_field_mismatch=True, false_when_class_mismatch=False)
    assert dc_eq(base, Point2D(1, 2), **cross_class) is True
    assert dc_eq(Point2D(1, 99), base, **cross_class) is False

    # except_when_class_mismatch takes precedence over each of the other flags
    for competing_flag in (
        {"false_when_class_mismatch": True},
        {"except_when_field_mismatch": True},
    ):
        with pytest.raises(TypeError, match=class_mismatch_msg):
            dc_eq(base, spatial, except_when_class_mismatch=True, **competing_flag)

    # empty dataclasses: equal within a class, and across classes when allowed
    blank = Empty()
    assert dc_eq(blank, Empty()) is True
    assert dc_eq(blank, AlsoEmpty(), false_when_class_mismatch=False) is True

    # a field declared with compare=False does not affect the result
    assert dc_eq(WithIgnored(1, 100), WithIgnored(1, 999)) is True

    # non-dataclass arguments are rejected
    with pytest.raises(TypeError):
        dc_eq(NotADataclass(1), NotADataclass(1))
def test_FORMAT_KEY():
    """_FORMAT_KEY is the string "__muutils_format__" and works as a dict key."""
    key = _FORMAT_KEY
    assert isinstance(key, str)
    assert key == "__muutils_format__"

    # used as a dictionary key (the common use case)
    payload = {key: "custom_type", "value": 42}
    assert payload[key] == "custom_type"
    assert key in payload
def test_edge_cases():
    """Test edge cases for utility functions: None values, empty containers, mixed types."""
    # string_as_lines with None
    assert string_as_lines(None) == []
    # Empty string splits to empty list (splitlines behavior)
    assert string_as_lines("") == []
    assert string_as_lines("single") == ["single"]

    # _recursive_hashify with empty containers
    assert _recursive_hashify([]) == ()
    assert _recursive_hashify({}) == ()
    assert _recursive_hashify(()) == ()

    # _recursive_hashify with mixed nested types
    mixed = {"list": [1, 2], "dict": {"nested": True}, "tuple": (3, 4)}
    result = _recursive_hashify(mixed)
    assert isinstance(result, tuple)
    # the point of hashifying: the result must actually be hashable
    hash(result)
    # dicts become (key, value) pairs (as in test_recursive_hashify); compare
    # order-insensitively because output key order is not pinned by these tests
    assert set(result) == {
        ("list", (1, 2)),
        ("dict", (("nested", True),)),
        ("tuple", (3, 4)),
    }

    # array_safe_eq with empty containers
    assert array_safe_eq([], []) is True
    assert array_safe_eq({}, {}) is True
    assert array_safe_eq((), ()) is True

    # array_safe_eq with None
    assert array_safe_eq(None, None) is True
    assert array_safe_eq(None, 0) is False

    # try_catch with function returning None: None is passed through untouched
    @try_catch
    def returns_none():
        return None

    assert returns_none() is None

    # UniversalContainer with various types, including unhashable ones
    uc = UniversalContainer()
    assert None in uc
    assert [] in uc
    assert {} in uc
    assert object() in uc