Coverage for muutils/json_serialize/json_serialize.py: 28%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""provides the basic framework for json serialization of objects 

2 

3notably: 

4 

5- `SerializerHandler` defines how to serialize a specific type of object 

6- `JsonSerializer` handles configuration for which handlers to use 

7- `json_serialize` provides the default configuration if you don't care -- call it on any object! 

8 

9""" 

10 

11from __future__ import annotations 

12 

13import inspect 

14import warnings 

15from dataclasses import dataclass, is_dataclass 

16from pathlib import Path 

17from typing import Any, Callable, Iterable, Mapping, Set, Union 

18 

19from muutils.errormode import ErrorMode 

20 

21try: 

22 from muutils.json_serialize.array import ArrayMode, serialize_array 

23except ImportError as e: 

24 ArrayMode = str # type: ignore[misc] 

25 serialize_array = lambda *args, **kwargs: None # noqa: E731 

26 warnings.warn( 

27 f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}", 

28 ImportWarning, 

29 ) 

30 

31from muutils.json_serialize.util import ( 

32 _FORMAT_KEY, 

33 Hashableitem, 

34 JSONitem, 

35 MonoTuple, 

36 SerializationException, 

37 _recursive_hashify, 

38 isinstance_namedtuple, 

39 safe_getsource, 

40 string_as_lines, 

41 try_catch, 

42) 

43 

44# pylint: disable=protected-access 

45 

# attribute names that the fallback handler in `DEFAULT_HANDLERS` reads off an
# object with `getattr(obj, k, None)` and stores as strings
SERIALIZER_SPECIAL_KEYS: MonoTuple[str] = (
    "__name__",
    "__doc__",
    "__module__",
    "__class__",
    "__dict__",
    "__annotations__",
)

# output key -> function that the fallback handler in `DEFAULT_HANDLERS` calls on the object
# NOTE(review): `str` and `dir` are used bare while the others go through `try_catch`
# (from `muutils.json_serialize.util`; presumably turns a raised exception into a
# returned value -- confirm against its definition)
SERIALIZER_SPECIAL_FUNCS: dict[str, Callable] = {
    "str": str,
    "dir": dir,
    "type": try_catch(lambda x: str(type(x).__name__)),
    "repr": try_catch(lambda x: repr(x)),
    "code": try_catch(lambda x: inspect.getsource(x)),
    "sourcefile": try_catch(lambda x: inspect.getsourcefile(x)),
}

# values of `str(type(obj))` for which the object is serialized simply as `str(obj)`;
# matching on the string means the types never have to be imported here
SERIALIZE_DIRECT_AS_STR: Set[str] = {
    "<class 'torch.device'>",
    "<class 'torch.dtype'>",
}

# location of an object inside the structure being serialized:
# a tuple whose items are annotated as `str` or `int` (handlers append keys / indices to it)
ObjectPath = MonoTuple[Union[str, int]]

70 

71 

@dataclass
class SerializerHandler:
    """pairs a predicate with a serialization function for one kind of object

    # Parameters:
     - `check : Callable[[JsonSerializer, Any, ObjectPath], bool]`
        takes a JsonSerializer, an object, and the current path, returns whether to use this handler
     - `serialize_func : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]`
        takes a JsonSerializer, an object, and the current path, returns the serialized object
     - `uid : str`
        identifier for the handler (intended to be unique)
     - `desc : str`
        description of the handler
    """

    # (serializer, object, path) -> should this handler be used?
    check: Callable[["JsonSerializer", Any, ObjectPath], bool]
    # (serializer, object, path) -> serialized form of the object
    serialize_func: Callable[["JsonSerializer", Any, ObjectPath], JSONitem]
    # identifier for the handler
    uid: str
    # human-readable description of what this handler covers
    desc: str

    def serialize(self) -> dict:
        """describe this handler as a dict: source and docs of both callables, plus metadata"""
        check_fn = self.check
        ser_fn = self.serialize_func

        # source code and docstring of the predicate
        check_info = {
            "code": safe_getsource(check_fn),
            "doc": string_as_lines(check_fn.__doc__),
        }
        # source code and docstring of the serialization function
        ser_info = {
            "code": safe_getsource(ser_fn),
            "doc": string_as_lines(ser_fn.__doc__),
        }

        return {
            "check": check_info,
            "serialize_func": ser_info,
            "uid": str(self.uid),
            # both of these are looked up on the serialization function, `None` if absent
            "source_pckg": getattr(ser_fn, "source_pckg", None),
            "__module__": getattr(ser_fn, "__module__", None),
            "desc": str(self.desc),
        }

110 

111 

# minimal handler set: JSON-native scalars, mappings, and lists/tuples.
# `JsonSerializer.json_serialize` walks its handlers in order and uses the first one
# whose `check` returns something truthy, so the order of this tuple matters
BASE_HANDLERS: MonoTuple[SerializerHandler] = (
    # bool, int, float, str and None are returned unchanged
    SerializerHandler(
        check=lambda self, obj, path: isinstance(
            obj, (bool, int, float, str, type(None))
        ),
        serialize_func=lambda self, obj, path: obj,
        uid="base types",
        desc="base types (bool, int, float, str, None)",
    ),
    # any `Mapping` becomes a dict: keys are passed through `str`, values are serialized
    # recursively with the original (un-stringified) key appended to the path
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Mapping),
        serialize_func=lambda self, obj, path: {
            str(k): self.json_serialize(v, tuple(path) + (k,)) for k, v in obj.items()
        },
        uid="dictionaries",
        desc="dictionaries",
    ),
    # lists and tuples become lists; each element is serialized recursively
    # with its index appended to the path
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (list, tuple)),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(list, tuple) -> list",
        desc="lists and tuples as lists",
    ),
)

138 

139 

def _serialize_override_serialize_func(
    self: "JsonSerializer", obj: Any, path: ObjectPath
) -> JSONitem:
    """delegate to the object's own `serialize()` method and return its result

    `self` and `path` are accepted only so the signature matches what
    `SerializerHandler.serialize_func` is called with; neither is used here
    """
    serialized: JSONitem = obj.serialize()
    return serialized

149 

150 

# default handler chain: `BASE_HANDLERS` first, then progressively more generic handlers,
# ending in a catch-all. `JsonSerializer.json_serialize` uses the first handler whose
# `check` passes, so order matters.
DEFAULT_HANDLERS: MonoTuple[SerializerHandler] = tuple(BASE_HANDLERS) + (
    # objects that know how to serialize themselves
    SerializerHandler(
        # TODO: allow for custom serialization handler name
        check=lambda self, obj, path: hasattr(obj, "serialize")
        and callable(obj.serialize),
        serialize_func=_serialize_override_serialize_func,
        uid=".serialize override",
        desc="objects with .serialize method",
    ),
    # NOTE(review): namedtuple instances are tuples, so with `BASE_HANDLERS` in front the
    # "(list, tuple) -> list" handler matches them first; this handler only takes effect
    # when used in a chain without that one -- confirm whether the ordering is intended
    SerializerHandler(
        check=lambda self, obj, path: isinstance_namedtuple(obj),
        # `path` is forwarded so errors in nested fields report their real location
        serialize_func=lambda self, obj, path: self.json_serialize(
            dict(obj._asdict()), path
        ),
        uid="namedtuple -> dict",
        desc="namedtuples as dicts",
    ),
    # dataclasses: one entry per declared field, each serialized recursively
    SerializerHandler(
        check=lambda self, obj, path: is_dataclass(obj),
        serialize_func=lambda self, obj, path: {
            k: self.json_serialize(getattr(obj, k), tuple(path) + (k,))
            for k in obj.__dataclass_fields__
        },
        uid="dataclass -> dict",
        desc="dataclasses as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Path),
        serialize_func=lambda self, obj, path: obj.as_posix(),
        uid="path -> str",
        desc="Path objects as posix strings",
    ),
    # the handlers below match on `str(type(obj))` so that torch / numpy / pandas
    # never have to be imported by this module
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) in SERIALIZE_DIRECT_AS_STR,
        serialize_func=lambda self, obj, path: str(obj),
        uid="obj -> str(obj)",
        desc="directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'numpy.ndarray'>",
        serialize_func=lambda self, obj, path: serialize_array(self, obj, path=path),
        uid="numpy.ndarray",
        desc="numpy arrays",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'torch.Tensor'>",
        # detach and move to cpu before handing the tensor to `serialize_array`
        serialize_func=lambda self, obj, path: serialize_array(
            self, obj.detach().cpu(), path=path
        ),
        uid="torch.Tensor",
        desc="pytorch tensors",
    ),
    SerializerHandler(
        check=lambda self, obj, path: (
            str(type(obj)) == "<class 'pandas.core.frame.DataFrame'>"
        ),
        serialize_func=lambda self, obj, path: {
            _FORMAT_KEY: "pandas.DataFrame",
            "columns": obj.columns.tolist(),
            "data": obj.to_dict(orient="records"),
            "path": path,  # type: ignore
        },
        uid="pandas.DataFrame",
        desc="pandas DataFrames",
    ),
    # anything iterable that was not claimed above (sets, generators, ...) becomes a list;
    # set, list and tuple are all `Iterable`, so the single check covers them as well
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Iterable),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(set, list, tuple, Iterable) -> list",
        desc="sets, lists, tuples, and Iterables as lists",
    ),
    # catch-all: always matches, emits stringified special attributes plus the
    # results of every function in `SERIALIZER_SPECIAL_FUNCS`
    SerializerHandler(
        check=lambda self, obj, path: True,
        serialize_func=lambda self, obj, path: {
            **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS},
            **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()},
        },
        uid="fallback",
        desc="fallback handler -- serialize object attributes and special functions as strings",
    ),
)

233 

234 

class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
     - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
     - `error_mode : ErrorMode`
        what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
        (defaults to `"except"`)
     - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
     - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
     - `write_only_format : bool`
        changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
     - `ValueError`: on init, if `args` is not empty
     - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        # `*args` exists only to reject positional use: every option is keyword-only
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers; `handlers_pre` come first and therefore take precedence
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        """serialize `obj` using the first handler whose `check` accepts it

        # Parameters:
         - `obj : Any`
            object to serialize
         - `path : ObjectPath`
            location of `obj` within the structure being serialized, used in error messages
            (defaults to `tuple()`)

        # Returns:
         - `JSONitem`
            output of the matching handler, or `repr(obj)` if serialization failed
            and `error_mode` is not `ErrorMode.EXCEPT`

        # Raises:
         - `SerializationException`: if anything goes wrong and `error_mode` is `ErrorMode.EXCEPT`
        """
        # bound before the loop so the error message can always refer to it,
        # even when `self.handlers` is empty and the loop body never runs
        handler: Union[SerializerHandler, None] = None
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        # rename the format key so a loader will not try to restore the object
                        if isinstance(output, dict) and _FORMAT_KEY in output:
                            new_fmt: JSONitem = output.pop(_FORMAT_KEY)
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            # compare enum members rather than raw strings: `__init__` stores an `ErrorMode`,
            # and `from_any` also covers the case where the attribute was later reassigned
            # NOTE(review): assumes `ErrorMode.from_any` accepts strings and that
            # `ErrorMode.WARN` exists -- confirm against `muutils.errormode`
            mode: ErrorMode = ErrorMode.from_any(self.error_mode)
            if mode == ErrorMode.EXCEPT:
                # keep the message readable for very large objects
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                handler_uid: str = handler.uid if handler is not None else "<none>"
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif mode == ErrorMode.WARN:
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            # any mode other than EXCEPT falls back to the repr of the object
            return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable

        serializes `obj` with `json_serialize`, then passes the result and `force`
        to `_recursive_hashify`
        """
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)

326 

327 

# shared serializer built with all-default settings; used by the module-level `json_serialize`
GLOBAL_JSON_SERIALIZER: JsonSerializer = JsonSerializer()

329 

330 

def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize `obj` to a json-serializable object using `GLOBAL_JSON_SERIALIZER`

    `path` is forwarded unchanged to `JsonSerializer.json_serialize`
    """
    default_serializer = GLOBAL_JSON_SERIALIZER
    return default_serializer.json_serialize(obj, path=path)