Coverage for muutils / json_serialize / util.py: 50%
114 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
1"""utilities for json_serialize"""
3from __future__ import annotations
5import dataclasses
6import functools
7import inspect
8import sys
9import typing
10import warnings
11from typing import Any, Callable, Iterable, TypeVar, Union
13from muutils.json_serialize.types import BaseType, Hashableitem
15if typing.TYPE_CHECKING:
16 pass
18_NUMPY_WORKING: bool
19try:
20 _NUMPY_WORKING = True
21except ImportError:
22 warnings.warn("numpy not found, cannot serialize numpy arrays!")
23 _NUMPY_WORKING = False
26# pyright: reportExplicitAny=false
28# At type-checking time, include array serialization types to avoid nominal type errors
29# This avoids superfluous imports at runtime
30# if TYPE_CHECKING:
31# from muutils.json_serialize.array import NumericList, SerializedArrayWithMeta
33# JSONitem = Union[
34# BaseType,
35# typing.Sequence["JSONitem"],
36# typing.Dict[str, "JSONitem"],
37# SerializedArrayWithMeta,
38# NumericList,
39# ]
40# else:
# recursive alias for a JSON-like value: a `BaseType`, a sequence of items,
# or a string-keyed dict of items
# TODO: figure out how to include "_SerializedSet" and "_SerializedFrozenset"
JSONitem = Union[BaseType, typing.Sequence["JSONitem"], typing.Dict[str, "JSONitem"]]

# a string-keyed dict whose values are `JSONitem`s
JSONdict = typing.Dict[str, JSONitem]
# under static type checking (and on python < 3.9, where `types.GenericAlias`
# does not exist) `MonoTuple` is simply an alias for `typing.Sequence`
if typing.TYPE_CHECKING or sys.version_info < (3, 9):
    MonoTuple = typing.Sequence
else:
    import types

    class MonoTuple:  # pyright: ignore[reportUnreachable]
        """tuple type hint, but for a tuple of any length with all the same type

        `MonoTuple[T]` evaluates to `tuple[T, ...]`.
        The class is only a subscriptable marker: it can be neither
        instantiated nor subclassed.
        """

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type MonoTuple cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__module__}")

        # idk why mypy thinks there is no such function in typing
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            """build the `tuple[T, ...]` alias for `MonoTuple[T]`

            - `MonoTuple[()]` returns plain `tuple`
            - a tuple/list holding exactly one type is unwrapped to that type
            - any other single parameter (class, `Union`, `X | Y`, `TypeVar`, ...)
              is used as the element type directly

            # Raises:
             - `TypeError`: if more than one type argument is given
            """
            if isinstance(params, (tuple, list)):
                if len(params) == 0:
                    return tuple
                if len(params) == 1:
                    return types.GenericAlias(tuple, (params[0], Ellipsis))
                raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")
            # previously the plain-class case built this alias but forgot to
            # return it, and non-iterable parameters fell through to `None`
            return types.GenericAlias(tuple, (params, Ellipsis))
# TYPING: `Any` is deliberate here -- membership is accepted for every possible object
class UniversalContainer:
    """a container that reports every object as a member

    `x in UniversalContainer()` evaluates to `True` for any `x`
    """

    def __contains__(self, x: Any) -> bool:  # pyright: ignore[reportAny]
        # membership never depends on `x`
        return True
def isinstance_namedtuple(x: Any) -> bool:  # pyright: ignore[reportAny]
    """checks if `x` is an instance of a `namedtuple` class

    the type of `x` must derive directly (and only) from `tuple` and expose a
    `_fields` attribute that is a tuple of strings

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    cls: type = type(x)  # pyright: ignore[reportUnknownVariableType, reportAny]
    bases: tuple[type, ...] = cls.__bases__
    derives_only_from_tuple: bool = len(bases) == 1 and bases[0] is tuple
    if not derives_only_from_tuple:
        return False

    # the type of `_fields` is unknown on purpose -- that is what gets checked
    field_names: Any = getattr(cls, "_fields", None)
    return isinstance(field_names, tuple) and all(
        isinstance(name, str) for name in field_names
    )
T_FuncTryCatchReturn = TypeVar("T_FuncTryCatchReturn")


def try_catch(
    func: Callable[..., T_FuncTryCatchReturn],
) -> Callable[..., Union[T_FuncTryCatchReturn, str]]:
    """wrap `func` so that exceptions are returned as a string instead of raised

    the returned callable forwards all arguments to `func`. on success it returns
    the result of `func`; if `func` raises an `Exception`, it returns the string
    `"<ExceptionClassName>: <message>"`
    """

    def guarded(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]:  # pyright: ignore[reportAny]
        try:
            result = func(*args, **kwargs)
        except Exception as err:
            return f"{err.__class__.__name__}: {err}"
        else:
            return result

    # copy name, docstring, etc. from `func` onto the wrapper
    return functools.wraps(func)(guarded)
133# TYPING: can we get rid of any of these?
134def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: # pyright: ignore[reportAny]
135 if isinstance(obj, typing.Mapping):
136 return tuple((k, _recursive_hashify(v)) for k, v in obj.items()) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
137 elif isinstance(obj, (bool, int, float, str)):
138 return obj
139 elif isinstance(obj, (tuple, list, Iterable)):
140 return tuple(_recursive_hashify(v) for v in obj) # pyright: ignore[reportUnknownVariableType]
141 else:
142 if force:
143 return str(obj) # pyright: ignore[reportAny]
144 else:
145 raise ValueError(f"cannot hashify:\n{obj}")
class SerializationException(Exception):
    """exception type of the json_serialize utilities; adds nothing beyond `Exception`"""
def string_as_lines(s: str | None) -> list[str]:
    """split a string into a list of its lines, without the line endings

    makes long strings easier to read in json, sort of like how jupyter notebooks do it.
    `None` gives an empty list
    """
    return [] if s is None else s.splitlines()
def safe_getsource(func: Callable[..., Any]) -> list[str]:
    """get the source code of `func` as a list of lines, never raising

    if `inspect.getsource` fails for any reason, the returned lines contain an
    error message with the exception text instead of the source
    """
    try:
        text: str = inspect.getsource(func)
    except Exception as e:
        text = f"Error: Unable to retrieve source code:\n{e}"
    return string_as_lines(text)
# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises
def array_safe_eq(a: Any, b: Any) -> bool:  # pyright: ignore[reportAny]
    """check if two objects are equal, account for if numpy arrays or torch tensors

    rules, applied in order:

    - identical objects are equal
    - objects of different types are never equal
    - numpy arrays / torch tensors are equal if their shapes match and all elements are equal
    - pandas dataframes are compared with their `equals` method
    - `str` / `bytes` are compared directly with `==`
    - other sequences are equal if they have the same length and pairwise-equal elements
    - mappings are equal if they have the same keys and equal values for each key
      (key order is ignored, as for plain `dict` equality)
    - anything else is compared with `==`. if that comparison raises `TypeError` or
      `ValueError`, a warning is emitted and `False` is returned

    # Returns:
     - `bool`: always a plain python `bool`
    """
    if a is b:
        return True

    if type(a) is not type(b):  # pyright: ignore[reportAny]
        return False

    # the types are identical from here on, so inspecting `a` is sufficient.
    # matching on the printed type avoids importing numpy / torch / pandas
    type_repr: str = str(type(a))  # pyright: ignore[reportAny, reportUnknownArgumentType]

    if type_repr in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # compare shapes first: `==` broadcasts, so arrays of different shapes
        # could otherwise be reported as equal
        return tuple(a.shape) == tuple(b.shape) and bool((a == b).all())  # pyright: ignore[reportAny]

    if type_repr == "<class 'pandas.core.frame.DataFrame'>":
        return a.equals(b)  # pyright: ignore[reportAny]

    # must come before the generic sequence case: iterating a string yields
    # strings again, so element-wise recursion is not guaranteed to terminate
    if isinstance(a, (str, bytes)):
        return a == b

    if isinstance(a, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping):
        # look values up by key instead of zipping the key views, so that
        # insertion order does not affect the result
        return len(a) == len(b) and all(  # pyright: ignore[reportUnknownArgumentType]
            key in b and array_safe_eq(val, b[key])
            for key, val in a.items()  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
        )

    try:
        return bool(a == b)  # pyright: ignore[reportAny]
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        # `NotImplemented` is not a `bool`: it is truthy inside `all(...)` and its
        # use in a boolean context is deprecated, so report "not equal" instead
        return False
# TYPING: see what can be done about so many `Any`s here
def dc_eq(
    dc1: Any,  # pyright: ignore[reportAny]
    dc2: Any,  # pyright: ignore[reportAny]
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    field values are compared with `array_safe_eq`, and only fields of `dc1`
    that have `compare=True` are considered

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    ```
                   [START]
                      ▼
               ┌─────────────┐
               │ dc1 is dc2? │───Yes───► (True)
               └──────┬──────┘
                      │No
                      ▼
               ┌───────────────┐
               │ classes match?│───Yes───► [compare field values] ───► (True/False)
               └──────┬────────┘
                      │No
                      ▼
        ┌────────────────────────────┐
        │ except_when_class_mismatch?│───Yes───► { raise TypeError }
        └─────────────┬──────────────┘
                      │No
                      ▼
        ┌────────────────────────────┐
        │ false_when_class_mismatch? │───Yes───► (False)
        └─────────────┬──────────────┘
                      │No
                      ▼
        ┌────────────────────────────┐
        │ except_when_field_mismatch?│───No────► [compare field values]
        └─────────────┬──────────────┘
                      │Yes
                      ▼
               ┌───────────────┐
               │ fields match? │───Yes───► [compare field values]
               └──────┬────────┘
                      │No
                      ▼
          { raise AttributeError }
    ```

    """
    # identical objects need no further inspection
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:  # pyright: ignore[reportAny]
        if except_when_class_mismatch:
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"  # pyright: ignore[reportAny]
            )
        if false_when_class_mismatch:
            return False
        # different classes, but the caller asked for a field-wise comparison
        if except_when_field_mismatch:
            dc1_fields: set[str] = {fld.name for fld in dataclasses.fields(dc1)}  # pyright: ignore[reportAny]
            dc2_fields: set[str] = {fld.name for fld in dataclasses.fields(dc2)}  # pyright: ignore[reportAny]
            if dc1_fields != dc2_fields:
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )

    # compare field by field, stopping at the first difference
    for fld in dataclasses.fields(dc1):  # pyright: ignore[reportAny]
        if not fld.compare:
            continue
        if not array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)):  # pyright: ignore[reportAny]
            return False
    return True