Coverage for muutils/json_serialize/util.py: 43% of 115 statements
coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""utilities for json_serialize"""
3from __future__ import annotations
5import dataclasses
6import functools
7import inspect
8import sys
9import typing
10import warnings
11from typing import Any, Callable, Iterable, Union
13_NUMPY_WORKING: bool
14try:
15 _NUMPY_WORKING = True
16except ImportError:
17 warnings.warn("numpy not found, cannot serialize numpy arrays!")
18 _NUMPY_WORKING = False

BaseType = Union[
    bool,
    int,
    float,
    str,
    None,
]

JSONitem = Union[
    BaseType,
    # mypy doesn't like recursive types, so we just go down a few levels manually
    typing.List[Union[BaseType, typing.List[Any], typing.Dict[str, Any]]],
    typing.Dict[str, Union[BaseType, typing.List[Any], typing.Dict[str, Any]]],
]
JSONdict = typing.Dict[str, JSONitem]

Hashableitem = Union[bool, int, float, str, tuple]

_FORMAT_KEY: str = "__muutils_format__"
_REF_KEY: str = "$ref"

# use `typing.Sequence` when type checking, or if python version < 3.9
if typing.TYPE_CHECKING or sys.version_info < (3, 9):
    MonoTuple = typing.Sequence
else:

    class MonoTuple:
        """tuple type hint, but for a tuple of any length with all the same type"""

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type MonoTuple cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__module__}")

        # idk why mypy thinks there is no such function in typing
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            if getattr(params, "__origin__", None) == typing.Union:
                return typing.GenericAlias(tuple, (params, Ellipsis))
            elif isinstance(params, type):
                return typing.GenericAlias(tuple, (params, Ellipsis))
            # test if it has a len and is iterable
            elif isinstance(params, Iterable):
                if len(params) == 0:
                    return tuple
                elif len(params) == 1:
                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
                else:
                    raise TypeError(
                        f"MonoTuple expects 1 type argument, got {params = }"
                    )
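
# Usage sketch (illustrative, not part of the original module): `MonoTuple[T]`
# is meant to be used as an annotation for a homogeneous tuple of any length,
# i.e. the same idea as `tuple[T, ...]`:
#
#     coords: MonoTuple[int] = (1, 2, 3)
#     names: MonoTuple[str] = ("a", "b")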


class UniversalContainer:
    """contains everything -- `x in UniversalContainer()` is always True"""

    def __contains__(self, x: Any) -> bool:
        return True
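
# Usage sketch (illustrative, not part of the original module): membership
# tests against a `UniversalContainer` always succeed, which makes it a handy
# "match everything" default for filters.
#
#     >>> 42 in UniversalContainer()
#     True
#     >>> "anything at all" in UniversalContainer()
#     True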


def isinstance_namedtuple(x: Any) -> bool:
    """checks if `x` is a `namedtuple`

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    t: type = type(x)
    b: tuple = t.__bases__
    if len(b) != 1 or (b[0] is not tuple):
        return False
    f: Any = getattr(t, "_fields", None)
    if not isinstance(f, tuple):
        return False
    return all(isinstance(n, str) for n in f)
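
# Usage sketch (illustrative, not part of the original module), with a
# throwaway `Point` namedtuple:
#
#     >>> from collections import namedtuple
#     >>> Point = namedtuple("Point", ["x", "y"])
#     >>> isinstance_namedtuple(Point(1, 2))
#     True
#     >>> isinstance_namedtuple((1, 2))
#     False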


def try_catch(func: Callable):
    """wraps the function to catch exceptions, returns serialized error message on exception

    the returned func returns the normal result on success, or an error message string on exception
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc
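
# Usage sketch (illustrative, not part of the original module): the wrapped
# function returns its normal result on success, and a string of the form
# "ExceptionName: message" on failure, so the outcome is always serializable.
#
#     >>> @try_catch
#     ... def inverse(x):
#     ...     return 1 / x
#     >>> inverse(2)
#     0.5
#     >>> inverse(0)
#     'ZeroDivisionError: division by zero'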


def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem:
    if isinstance(obj, typing.Mapping):
        return tuple((k, _recursive_hashify(v)) for k, v in obj.items())
    elif isinstance(obj, (tuple, list, Iterable)):
        return tuple(_recursive_hashify(v) for v in obj)
    elif isinstance(obj, (bool, int, float, str)):
        return obj
    else:
        if force:
            return str(obj)
        else:
            raise ValueError(f"cannot hashify:\n{obj}")
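
# Usage sketch (illustrative, not part of the original module): mappings become
# tuples of (key, value) pairs and other iterables become tuples, so the result
# can be used as a dict key or stored in a set.
#
#     >>> _recursive_hashify({"a": [1, 2], "b": 3})
#     (('a', (1, 2)), ('b', 3))
#
# With the default `force=True`, anything not otherwise handled is stringified
# via `str(obj)`; with `force=False` a `ValueError` is raised instead.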


class SerializationException(Exception):
    pass


def string_as_lines(s: str | None) -> list[str]:
    """for easier reading of long strings in json, split up by newlines

    sort of like how jupyter notebooks do it
    """
    if s is None:
        return list()
    else:
        return s.splitlines(keepends=False)
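
# Usage sketch (illustrative, not part of the original module):
#
#     >>> string_as_lines("first line\nsecond line")
#     ['first line', 'second line']
#     >>> string_as_lines(None)
#     []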


def safe_getsource(func) -> list[str]:
    try:
        return string_as_lines(inspect.getsource(func))
    except Exception as e:
        return string_as_lines(f"Error: Unable to retrieve source code:\n{e}")


# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises
def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, accounting for numpy arrays and torch tensors"""
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    if (
        str(type(a)) == "<class 'numpy.ndarray'>"
        and str(type(b)) == "<class 'numpy.ndarray'>"
    ) or (
        str(type(a)) == "<class 'torch.Tensor'>"
        and str(type(b)) == "<class 'torch.Tensor'>"
    ):
        return (a == b).all()

    if (
        str(type(a)) == "<class 'pandas.core.frame.DataFrame'>"
        and str(type(b)) == "<class 'pandas.core.frame.DataFrame'>"
    ):
        return a.equals(b)

    if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence):
        if len(a) == 0 and len(b) == 0:
            return True
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, (dict, typing.Mapping)) and isinstance(b, (dict, typing.Mapping)):
        return len(a) == len(b) and all(
            array_safe_eq(k1, k2) and array_safe_eq(a[k1], b[k2])
            for k1, k2 in zip(a.keys(), b.keys())
        )

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        return NotImplemented  # type: ignore[return-value]
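
# Usage sketch (illustrative, not part of the original module): unlike a bare
# `a == b`, this handles element-wise containers such as numpy arrays, where
# `bool(a == b)` would raise "truth value ... is ambiguous".
#
#     >>> import numpy as np
#     >>> bool(array_safe_eq(np.array([1, 2]), np.array([1, 2])))
#     True
#     >>> array_safe_eq([np.array([1, 2])], [np.array([1, 3])])
#     False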


def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, raise a `TypeError` when the classes differ.
        if `False`, return `False` by default, or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, return `False` when the classes differ.
        if `False`, attempt to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, raise an `AttributeError` when the field names differ.
        (default: `False`)

    # Returns:
    - `bool`: `True` if the dataclasses are equal, `False` otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes and `except_when_class_mismatch` is `True`
    - `AttributeError`: if the dataclasses have different fields and `except_when_field_mismatch` is `True`

    # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
    ```
    [START]
      ▼
    dc1 is dc2?
     ├─ Yes ─► (True)
     └─ No ──► classes match?
                ├─ Yes ─► fields keys match?
                │          ├─ Yes ─► field values match?
                │          │          ├─ Yes ─► (True)
                │          │          └─ No ──► (False)
                │          └─ No ──► (False)
                └─ No ──► except when class mismatch?
                           ├─ Yes ─► raise TypeError
                           └─ No ──► except when field mismatch?
                                      ├─ Yes ─► raise AttributeError
                                      └─ No ──► (False)
    ```
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        if except_when_field_mismatch:
            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
            fields_match: bool = set(dc1_fields) == set(dc2_fields)
            if not fields_match:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
        return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
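
# Usage sketch (illustrative, not part of the original module), using a
# hypothetical `Config` dataclass: plain `==` on dataclasses holding arrays
# raises, since the generated `__eq__` ends up calling `bool(arr1 == arr2)`;
# `dc_eq` compares field by field via `array_safe_eq` instead.
#
#     >>> import dataclasses
#     >>> import numpy as np
#     >>> @dataclasses.dataclass
#     ... class Config:
#     ...     name: str
#     ...     weights: np.ndarray
#     >>> a = Config("a", np.array([1.0, 2.0]))
#     >>> b = Config("a", np.array([1.0, 2.0]))
#     >>> dc_eq(a, b)
#     True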