Coverage for muutils/json_serialize/util.py: 43%

115 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""utilities for json_serialize""" 

2 

3from __future__ import annotations 

4 

5import dataclasses 

6import functools 

7import inspect 

8import sys 

9import typing 

10import warnings 

11from typing import Any, Callable, Iterable, Union 

12 

# Flag recording whether numpy could be imported when this module was loaded.
_NUMPY_WORKING: bool
try:
    # The import is what makes the `except ImportError` branch reachable;
    # with an import-free `try` body the flag was unconditionally True.
    import numpy  # noqa: F401

    _NUMPY_WORKING = True
except ImportError:
    warnings.warn("numpy not found, cannot serialize numpy arrays!")
    _NUMPY_WORKING = False

19 

20 

# scalar values that correspond directly to JSON primitives
BaseType = Union[bool, int, float, str, None]

# Any JSON-compatible value.
# mypy doesn't like recursive types, so the nesting is written out by hand for
# a few levels and falls back to `Any` below that.
JSONitem = Union[
    BaseType,
    typing.List[Union[BaseType, typing.List[Any], typing.Dict[str, Any]]],
    typing.Dict[str, Union[BaseType, typing.List[Any], typing.Dict[str, Any]]],
]

# a JSON object: string keys mapped to JSON values
JSONdict = typing.Dict[str, JSONitem]

# return type used by `_recursive_hashify`
Hashableitem = Union[bool, int, float, str, tuple]

38 

39 

# Key string "__muutils_format__".
# NOTE(review): presumably the dict key under which a serialized object records
# its format -- confirm against the code that reads/writes it.
_FORMAT_KEY: str = "__muutils_format__"

# Key string "$ref".
# NOTE(review): presumably marks a reference to another object -- confirm
# against the code that reads/writes it.
_REF_KEY: str = "$ref"

42 

# Static type checkers, and interpreters older than 3.9 (which have no
# `GenericAlias`), simply see `MonoTuple` as `typing.Sequence`.
if typing.TYPE_CHECKING or sys.version_info < (3, 9):
    MonoTuple = typing.Sequence
else:

    class MonoTuple:
        """tuple type hint, but for a tuple of any length with all the same type

        `MonoTuple[T]` evaluates to `tuple[T, ...]`. The class exists only to be
        subscripted: it can be neither instantiated nor subclassed.
        """

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type MonoTuple cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__module__}")

        # idk why mypy thinks there is no such function in typing
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            """Build `tuple[params, ...]` for a single type argument.

            # Parameters:
             - `params`: a `typing.Union`, a plain class, or an iterable holding
               zero or one type arguments

            # Returns:
             - `tuple[T, ...]` for a single argument `T`, or bare `tuple` for an
               empty iterable. Falls through (returns `None`) when `params` is
               none of the supported kinds.

            # Raises:
             - `TypeError`: if an iterable with more than one element is given
            """
            if getattr(params, "__origin__", None) == typing.Union:
                return typing.GenericAlias(tuple, (params, Ellipsis))
            elif isinstance(params, type):
                # this branch used to build the alias without returning it,
                # so `MonoTuple[int]` evaluated to `None`
                return typing.GenericAlias(tuple, (params, Ellipsis))
            # test if has len and is iterable
            elif isinstance(params, Iterable):
                if len(params) == 0:
                    return tuple
                elif len(params) == 1:
                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
                else:
                    raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")

74 

75 

class UniversalContainer:
    """A container that reports every object as a member.

    `x in UniversalContainer()` evaluates to `True` for any `x`.
    """

    def __contains__(self, x: Any) -> bool:
        # membership is unconditional; `x` is never inspected
        return True

81 

82 

def isinstance_namedtuple(x: Any) -> bool:
    """Return whether `x` looks like an instance of a `namedtuple` class.

    The type of `x` must derive directly from `tuple` (and nothing else) and
    carry a `_fields` attribute that is a tuple of strings.

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    cls: type = type(x)
    bases: tuple = cls.__bases__
    derives_only_from_tuple: bool = len(bases) == 1 and bases[0] is tuple
    if not derives_only_from_tuple:
        return False

    field_names: Any = getattr(cls, "_fields", None)
    if isinstance(field_names, tuple):
        return all(isinstance(name, str) for name in field_names)
    return False

96 

97 

def try_catch(func: Callable):
    """Decorate `func` so that exceptions are reported as a string instead of raised.

    The wrapped function returns `func`'s normal result on success. If `func`
    raises any `Exception`, the wrapper returns the message
    `"<ExceptionClassName>: <exception text>"` instead.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except Exception as err:
            return f"{err.__class__.__name__}: {err}"
        return result

    return wrapper

112 

113 

def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem:
    """Recursively convert `obj` into nested tuples and scalars.

    # Parameters:
     - `obj`: the object to convert
     - `force: bool`
       if `True`, objects that are not a scalar, mapping, or iterable are
       replaced by `str(obj)`; if `False`, they raise `ValueError`.
       Applies at every nesting level.
       (default: `True`)

    # Returns:
     - `bool`/`int`/`float`/`str` values unchanged, mappings as a tuple of
       `(key, converted_value)` pairs, other iterables as a tuple of
       converted elements

    # Raises:
     - `ValueError`: if an object cannot be converted and `force` is `False`
    """
    # Scalars must be tested first: `str` is itself iterable, and iterating a
    # one-character string yields that same string again, so a string that
    # reaches the iterable branch recurses forever.
    if isinstance(obj, (bool, int, float, str)):
        return obj
    if isinstance(obj, typing.Mapping):
        return tuple((k, _recursive_hashify(v, force)) for k, v in obj.items())
    if isinstance(obj, Iterable):
        return tuple(_recursive_hashify(v, force) for v in obj)
    if force:
        return str(obj)
    raise ValueError(f"cannot hashify:\n{obj}")

126 

127 

class SerializationException(Exception):
    """Exception subclass declared in this module; adds nothing to `Exception`."""

130 

131 

def string_as_lines(s: str | None) -> list[str]:
    """Split a string into its lines, for easier reading of long strings in json.

    Line terminators are dropped. `None` yields an empty list.

    sort of like how jupyter notebooks do it
    """
    return [] if s is None else s.splitlines(keepends=False)

141 

142 

def safe_getsource(func) -> list[str]:
    """Return the source of `func` as a list of lines, never raising.

    If `inspect.getsource` fails for any reason, the returned lines hold an
    error message that includes the exception text instead.
    """
    try:
        text = inspect.getsource(func)
    except Exception as err:
        text = f"Error: Unable to retrieve source code:\n{err}"
    return string_as_lines(text)

148 

149 

# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises
def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, account for if numpy arrays or torch tensors

    # Parameters:
     - `a`, `b`: the objects to compare

    # Returns:
     - `bool`: `True` if the objects are the same object or compare equal.
       Objects of different types are never equal. Sequences and mappings are
       compared element-wise with this same function. If `==` itself raises
       `TypeError` or `ValueError`, a warning is emitted and `False` is returned.
    """
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    # Both types are identical from here on, so inspecting `a` alone suffices.
    # Matching on the type's string form means numpy/torch/pandas never have
    # to be imported here.
    type_repr: str = str(type(a))

    if type_repr in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # `==` broadcasts: differently shaped arrays could compare equal
        # element-wise, or raise. Different shapes are simply not equal.
        if a.shape != b.shape:
            return False
        # `.all()` yields an array/tensor scalar; convert to a real bool
        return bool((a == b).all())

    if type_repr == "<class 'pandas.core.frame.DataFrame'>":
        return bool(a.equals(b))

    # Strings are Sequences whose elements are again strings; sending them
    # through the element-wise branch can recurse without end.
    if isinstance(a, (str, bytes)):
        return a == b

    if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, (dict, typing.Mapping)) and isinstance(b, (dict, typing.Mapping)):
        # look values up by key rather than pairing keys positionally, so that
        # insertion order does not affect the result
        return len(a) == len(b) and all(
            key in b and array_safe_eq(a[key], b[key]) for key in a
        )

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        # `NotImplemented` must not be used in a boolean context (callers such
        # as `all(...)` do exactly that), so report "not equal" instead
        return False

190 

191 

def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if the classes are different and `except_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the field names are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes and `except_when_class_mismatch` is `True`
    - `AttributeError`: if the dataclasses have different fields and `except_when_field_mismatch` is `True`

    # Decision order:

    1. `dc1 is dc2` -> `True`
    2. classes match -> compare the values of all fields with `compare=True`
    3. classes differ and `except_when_class_mismatch` -> raise `TypeError`
    4. classes differ, field names differ -> raise `AttributeError` if
       `except_when_field_mismatch`, otherwise `False`
    5. classes differ, field names match (or were not checked) -> `False` if
       `false_when_class_mismatch`, otherwise compare the field values as in 2.
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )

        # Field names are only inspected when some flag needs them, so that a
        # plain class mismatch still returns False without calling
        # `dataclasses.fields` on either argument.
        if except_when_field_mismatch or not false_when_class_mismatch:
            dc1_fields: set = {fld.name for fld in dataclasses.fields(dc1)}
            dc2_fields: set = {fld.name for fld in dataclasses.fields(dc2)}
            if dc1_fields != dc2_fields:
                if except_when_field_mismatch:
                    raise AttributeError(
                        f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                    )
                return False

        if false_when_class_mismatch:
            return False
        # otherwise: classes differ but field names agree, so fall through and
        # compare the field values

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )

280 if fld.compare 

281 )