Coverage for muutils / json_serialize / util.py: 50%

114 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:25 -0700

1"""utilities for json_serialize""" 

2 

3from __future__ import annotations 

4 

5import dataclasses 

6import functools 

7import inspect 

8import sys 

9import typing 

10import warnings 

11from typing import Any, Callable, Iterable, TypeVar, Union 

12 

13from muutils.json_serialize.types import BaseType, Hashableitem 

14 

if typing.TYPE_CHECKING:
    # placeholder: no type-checking-only imports are present in this branch
    pass

# module-level flag; assigned exactly once by the try/except below
_NUMPY_WORKING: bool
try:
    # NOTE(review): nothing is imported inside this `try`, so `ImportError`
    # cannot be raised here and `_NUMPY_WORKING` always ends up `True`.
    # Presumably an `import numpy` line used to live here -- confirm intent.
    _NUMPY_WORKING = True
except ImportError:
    warnings.warn("numpy not found, cannot serialize numpy arrays!")
    _NUMPY_WORKING = False

24 

25 

26# pyright: reportExplicitAny=false 

27 

28# At type-checking time, include array serialization types to avoid nominal type errors 

29# This avoids superfluous imports at runtime 

30# if TYPE_CHECKING: 

31# from muutils.json_serialize.array import NumericList, SerializedArrayWithMeta 

32 

33# JSONitem = Union[ 

34# BaseType, 

35# typing.Sequence["JSONitem"], 

36# typing.Dict[str, "JSONitem"], 

37# SerializedArrayWithMeta, 

38# NumericList, 

39# ] 

40# else: 

41 

# Recursive type alias: a `BaseType` value, a sequence of `JSONitem`, or a
# string-keyed dict of `JSONitem`. The self-references are written as string
# forward references because the name is still being defined at this point.
JSONitem = Union[
    BaseType,
    typing.Sequence["JSONitem"],
    typing.Dict[str, "JSONitem"],
    # TODO: figure this out
    # "_SerializedSet",
    # "_SerializedFrozenset",
]

# string-keyed dict whose values are `JSONitem`
JSONdict = typing.Dict[str, JSONitem]

52 

53 

54# TODO: this bit is very broken 

55# or if python version <3.9 

if typing.TYPE_CHECKING or sys.version_info < (3, 9):
    # static type checkers (and interpreters older than 3.9) only ever see
    # a plain `Sequence`
    MonoTuple = typing.Sequence
else:

    class MonoTuple:  # pyright: ignore[reportUnreachable]
        """tuple type hint, but for a tuple of any length with all the same type

        `MonoTuple[T]` evaluates to `tuple[T, ...]`.
        The class itself can be neither instantiated nor subclassed; it only
        exists to be subscripted.
        """

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type MonoTuple cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__module__}")

        # idk why mypy thinks there is no such function in typing
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            """build the `tuple[..., ...]` alias for `MonoTuple[params]`

            # Parameters:
            - `params`: the subscript; a `Union`, a single class, or an
              iterable holding zero or one type

            # Returns:
            - `tuple[params, ...]` for a `Union` or a class
            - plain `tuple` for an empty iterable
            - `tuple[params[0], ...]` for a one-element iterable
            - `None` for anything that matches none of the branches below

            # Raises:
            - `TypeError`: if `params` is an iterable with more than one element
            """
            if getattr(params, "__origin__", None) == typing.Union:
                return typing.GenericAlias(tuple, (params, Ellipsis))
            elif isinstance(params, type):
                # the alias has to be returned: without the `return`,
                # `MonoTuple[int]` would evaluate to `None`
                return typing.GenericAlias(tuple, (params, Ellipsis))
            # test if has len and is iterable
            elif isinstance(params, Iterable):
                if len(params) == 0:
                    return tuple
                elif len(params) == 1:
                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
                else:
                    raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")

86 

87 

88# TYPING: we allow `Any` here because the container is... universal 

class UniversalContainer:
    """container that reports every object as a member

    `x in UniversalContainer()` evaluates to `True` for any `x`
    """

    def __contains__(self, x: Any) -> bool:  # pyright: ignore[reportAny]
        # membership is unconditional; the argument is never inspected
        return True

94 

95 

def isinstance_namedtuple(x: Any) -> bool:  # pyright: ignore[reportAny]
    """report whether `x` looks like a `namedtuple` instance

    `x` qualifies when its type has `tuple` as its one and only direct base
    and exposes a `_fields` attribute that is a tuple made up solely of strings.

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    cls: type = type(x)  # pyright: ignore[reportUnknownVariableType, reportAny]
    bases: tuple[type, ...] = cls.__bases__
    derives_only_from_tuple: bool = len(bases) == 1 and bases[0] is tuple
    if not derives_only_from_tuple:
        return False

    field_names: Any = getattr(cls, "_fields", None)
    # the element types are unknown on purpose -- verifying them is the point
    return isinstance(field_names, tuple) and all(
        isinstance(name, str) for name in field_names
    )

110 

111 

T_FuncTryCatchReturn = TypeVar("T_FuncTryCatchReturn")


def try_catch(
    func: Callable[..., T_FuncTryCatchReturn],
) -> Callable[..., Union[T_FuncTryCatchReturn, str]]:
    """wrap `func` so that exceptions come back as a string instead of propagating

    the returned callable forwards all arguments to `func`. on success it hands
    back the result of `func`; if `func` raises an `Exception`, it returns
    `"<ExceptionClassName>: <message>"` instead.
    """

    def guarded(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]:  # pyright: ignore[reportAny]
        try:
            result = func(*args, **kwargs)
        except Exception as err:
            return f"{err.__class__.__name__}: {err}"
        return result

    # copy the metadata of `func` (name, docstring, ...) onto the wrapper
    return functools.wraps(func)(guarded)

131 

132 

133# TYPING: can we get rid of any of these? 

134def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: # pyright: ignore[reportAny] 

135 if isinstance(obj, typing.Mapping): 

136 return tuple((k, _recursive_hashify(v)) for k, v in obj.items()) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] 

137 elif isinstance(obj, (bool, int, float, str)): 

138 return obj 

139 elif isinstance(obj, (tuple, list, Iterable)): 

140 return tuple(_recursive_hashify(v) for v in obj) # pyright: ignore[reportUnknownVariableType] 

141 else: 

142 if force: 

143 return str(obj) # pyright: ignore[reportAny] 

144 else: 

145 raise ValueError(f"cannot hashify:\n{obj}") 

146 

147 

class SerializationException(Exception):
    """plain `Exception` subclass; adds no attributes or behavior of its own"""

150 

151 

def string_as_lines(s: str | None) -> list[str]:
    """split a string into its lines (line endings dropped), for easier reading in json

    `None` gives an empty list.

    sort of like how jupyter notebooks do it
    """
    return [] if s is None else s.splitlines()

161 

162 

def safe_getsource(func: Callable[..., Any]) -> list[str]:
    """return the source of `func` as a list of lines, never raising

    if `inspect.getsource` fails with an `Exception`, the returned lines hold
    an error message that includes the exception text instead.
    """
    try:
        text: str = inspect.getsource(func)
    except Exception as err:
        text = f"Error: Unable to retrieve source code:\n{err}"
    return string_as_lines(text)

168 

169 

170# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises 

def array_safe_eq(a: Any, b: Any) -> bool:  # pyright: ignore[reportAny]
    """check if two objects are equal, account for if numpy arrays or torch tensors

    # Parameters:
    - `a`, `b`: the objects to compare

    # Returns:
    - `True` if `a is b`
    - `False` if the types of `a` and `b` are not identical
    - for numpy arrays / torch tensors: `True` only if the shapes match and
      all elements compare equal
    - for pandas DataFrames: the result of `a.equals(b)`
    - for `str` / `bytes`: the result of `a == b`
    - for other sequences: `True` if lengths match and all items are
      pairwise `array_safe_eq`
    - for mappings: `True` if lengths match and keys and values are pairwise
      `array_safe_eq` *in iteration order*
    - otherwise `bool(a == b)`; if that comparison raises `TypeError` or
      `ValueError`, a warning is emitted and `NotImplemented` is returned
    """
    if a is b:
        return True

    if type(a) is not type(b):  # pyright: ignore[reportAny]
        return False

    # both objects share the exact same type from here on, so one lookup is enough.
    # types are matched by name, which avoids importing numpy/torch/pandas here
    type_name: str = str(type(a))  # pyright: ignore[reportAny, reportUnknownArgumentType]

    if type_name in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # elementwise `==` broadcasts, so arrays with different shapes could
        # compare as "all equal" (or raise); different shapes are never equal
        if a.shape != b.shape:  # pyright: ignore[reportAny]
            return False
        # coerce the array/tensor scalar to a plain `bool`
        return bool((a == b).all())  # pyright: ignore[reportAny]

    if type_name == "<class 'pandas.core.frame.DataFrame'>":
        return a.equals(b)  # pyright: ignore[reportAny]

    if isinstance(a, (str, bytes)):
        # a `str` is a sequence of one-character `str`s, so recursing into it
        # item by item never bottoms out unless the characters happen to be
        # the very same object; compare directly instead
        return bool(a == b)

    if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping) and isinstance(b, typing.Mapping):
        # NOTE: keys are paired up in iteration order, so two mappings holding
        # the same items in a different order compare as unequal
        return len(a) == len(b) and all(  # pyright: ignore[reportUnknownArgumentType]
            array_safe_eq(k1, k2) and array_safe_eq(a[k1], b[k2])
            for k1, k2 in zip(a.keys(), b.keys())  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
        )

    try:
        return bool(a == b)  # pyright: ignore[reportAny]
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        return NotImplemented  # type: ignore[return-value]

210 

211 

212# TYPING: see what can be done about so many `Any`s here 

def dc_eq(
    dc1: Any,  # pyright: ignore[reportAny]
    dc2: Any,  # pyright: ignore[reportAny]
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    compare two dataclass instances field by field using `array_safe_eq`,
    so that fields which (might) hold numpy arrays can be compared

    only the fields of `dc1` that have `compare=True` are looked at.

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, raise `TypeError` when the classes differ.
        otherwise what happens on a class mismatch is decided by `false_when_class_mismatch`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only consulted if `except_when_class_mismatch` is `False`.
        if `True`, return `False` when the classes differ.
        if `False`, go on and compare the fields anyway.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only consulted if the classes differ and both `except_when_class_mismatch`
        and `false_when_class_mismatch` are `False`.
        if `True`, raise `AttributeError` when the two sets of field names differ.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the classes differ and `except_when_class_mismatch` is set
    - `AttributeError`: if the field names differ and `except_when_field_mismatch` is set
        (reading a field of `dc1` that `dc2` lacks also raises `AttributeError`)

    ```
              [START]
                 │
                 ▼
          ┌─────────────┐
          │ dc1 is dc2? │───Yes───► (True)
          └──────┬──────┘
                 │No
                 ▼
         ┌────────────────┐
         │ classes match? │───Yes───► [compare field values] ───► (True/False)
         └───────┬────────┘
                 │No
                 ▼
    ┌────────────────────────────┐
    │ except_when_class_mismatch?│───Yes───► { raise TypeError }
    └────────────┬───────────────┘
                 │No
                 ▼
    ┌────────────────────────────┐
    │ false_when_class_mismatch? │───Yes───► (False)
    └────────────┬───────────────┘
                 │No
                 ▼
    ┌────────────────────────────┐
    │ except_when_field_mismatch?│───No────► [compare field values]
    └────────────┬───────────────┘
                 │Yes
                 ▼
         ┌───────────────┐
         │ fields match? │───Yes───► [compare field values]
         └───────┬───────┘
                 │No
                 ▼
      { raise AttributeError }
    ```

    """
    # identical object: nothing to compare
    if dc1 is dc2:
        return True

    classes_differ: bool = dc1.__class__ is not dc2.__class__  # pyright: ignore[reportAny]
    if classes_differ:
        if except_when_class_mismatch:
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"  # pyright: ignore[reportAny]
            )
        if false_when_class_mismatch:
            # bail out before touching any field
            return False
        # different classes, but the caller asked for a field comparison anyway
        if except_when_field_mismatch:
            names_1: set[str] = {fld.name for fld in dataclasses.fields(dc1)}  # pyright: ignore[reportAny]
            names_2: set[str] = {fld.name for fld in dataclasses.fields(dc2)}  # pyright: ignore[reportAny]
            if names_1 != names_2:
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{names_1}` and `{names_2}`"
                )

    # stop at the first compared field whose values are not equal
    for fld in dataclasses.fields(dc1):  # pyright: ignore[reportAny]
        if not fld.compare:
            continue
        if not array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)):  # pyright: ignore[reportAny]
            return False
    return True