Coverage for muutils / json_serialize / json_serialize.py: 51%

67 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:25 -0700

1"""provides the basic framework for json serialization of objects 

2 

3notably: 

4 

5- `SerializerHandler` defines how to serialize a specific type of object 

6- `JsonSerializer` handles configuration for which handlers to use 

7- `json_serialize` provides the default configuration if you don't care -- call it on any object! 

8 

9""" 

10 

11from __future__ import annotations 

12 

13import inspect 

14import warnings 

15from dataclasses import dataclass, is_dataclass 

16from pathlib import Path 

17from typing import ( 

18 TYPE_CHECKING, 

19 Any, 

20 Callable, 

21 Iterable, 

22 Mapping, 

23 Set, 

24 Union, 

25 cast, 

26 overload, 

27) 

28 

29from muutils.errormode import ErrorMode 

30 

31if TYPE_CHECKING: 

32 # always need array.py for type checking 

33 from muutils.json_serialize.array import ArrayMode, serialize_array 

34else: 

35 try: 

36 from muutils.json_serialize.array import ArrayMode, serialize_array 

37 except ImportError as e: 

38 # TYPING: obviously, these types are all wrong if we can't import array.py 

39 ArrayMode = str # type: ignore[misc] 

40 serialize_array = lambda *args, **kwargs: None # type: ignore[assignment, invalid-assignment] # noqa: E731 

41 warnings.warn( 

42 f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}", 

43 ImportWarning, 

44 ) 

45 

46from muutils.json_serialize.types import ( 

47 _FORMAT_KEY, 

48 Hashableitem, 

49) # pyright: ignore[reportPrivateUsage] 

50 

51from muutils.json_serialize.util import ( 

52 JSONdict, 

53 JSONitem, 

54 MonoTuple, 

55 SerializationException, 

56 _recursive_hashify, # pyright: ignore[reportPrivateUsage, reportUnknownVariableType] 

57 isinstance_namedtuple, 

58 safe_getsource, 

59 string_as_lines, 

60 try_catch, 

61) 

62 

63# pylint: disable=protected-access 

64 

# attribute names read off an object (each stringified, `None` if missing)
# by the catch-all "fallback" handler
SERIALIZER_SPECIAL_KEYS: MonoTuple[str] = (
    "__name__", "__doc__", "__module__",
    "__class__", "__dict__", "__annotations__",
)

# name -> function applied to the object by the catch-all "fallback" handler;
# everything except `str` and `dir` is wrapped in `try_catch`
SERIALIZER_SPECIAL_FUNCS: dict[str, Callable[..., str | list[str]]] = dict(
    str=str,
    dir=dir,
    type=try_catch(lambda obj: str(type(obj).__name__)),  # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType]
    repr=try_catch(lambda obj: repr(obj)),  # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType]
    code=try_catch(lambda obj: inspect.getsource(obj)),  # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType]
    sourcefile=try_catch(lambda obj: str(inspect.getsourcefile(obj))),  # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType]
)

# `str(type(obj))` values whose instances are serialized directly as `str(obj)`;
# compared by string so torch never has to be imported here
SERIALIZE_DIRECT_AS_STR: Set[str] = {"<class 'torch.device'>", "<class 'torch.dtype'>"}

# location of a value inside the top-level object: dict keys / sequence indices
ObjectPath = MonoTuple[Union[str, int]]

89 

90 

@dataclass
class SerializerHandler:
    """a handler for a specific type of object

    handlers are tried in order by `JsonSerializer.json_serialize`; the first whose
    `check` returns `True` has its `serialize_func` used

    # Parameters:
     - `check : Callable[[JsonSerializer, Any, ObjectPath], bool]`
        takes a JsonSerializer, an object, and the current path, returns whether to use this handler
     - `serialize_func : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]`
        takes a JsonSerializer, an object, and the current path, returns the serialized object
     - `uid : str`
        unique identifier for the handler
     - `desc : str`
        description of the handler
    """

    # (self_config, object, path) -> whether to use this handler
    check: Callable[["JsonSerializer", Any, ObjectPath], bool]
    # (self_config, object, path) -> serialized object
    serialize_func: Callable[["JsonSerializer", Any, ObjectPath], JSONitem]
    # unique identifier for the handler
    uid: str
    # description of this serializer
    desc: str

    def serialize(self) -> JSONdict:
        """serialize the handler info

        returns the source code and docstring (as lines) of both `check` and
        `serialize_func`, plus `uid`, `desc`, and the `source_pckg` / `__module__`
        attributes of `serialize_func` (`None` when absent)
        """
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the serialize function
            "serialize_func": {
                "code": safe_getsource(self.serialize_func),
                "doc": string_as_lines(self.serialize_func.__doc__),
            },
            # get the uid, source_pckg, module, and desc
            "uid": str(self.uid),
            "source_pckg": getattr(self.serialize_func, "source_pckg", None),
            "__module__": getattr(self.serialize_func, "__module__", None),
            "desc": str(self.desc),
        }

129 

130 

# minimal handler chain: JSON-native scalars and builtin containers.
# handlers are tried in order and the first whose `check` returns True is used,
# so the namedtuple handler must stay ahead of the (list, tuple) handler.
BASE_HANDLERS: MonoTuple[SerializerHandler] = (
    # scalars are returned unchanged
    # NOTE(review): non-finite floats (nan/inf) also pass through unchanged;
    # they are not valid strict JSON
    SerializerHandler(
        check=lambda self, obj, path: isinstance(
            obj, (bool, int, float, str, type(None))
        ),
        serialize_func=lambda self, obj, path: obj,
        uid="base types",
        desc="base types (bool, int, float, str, None)",
    ),
    # any Mapping -> dict with `str(key)` keys and recursively serialized values.
    # the raw key is appended to the path. keys whose `str()` coincide
    # (e.g. 1 and "1") collide, and the later one wins
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Mapping),
        serialize_func=lambda self, obj, path: {
            str(k): self.json_serialize(v, tuple(path) + (k,)) for k, v in obj.items()
        },
        uid="dictionaries",
        desc="dictionaries",
    ),
    # namedtuples -> dict of field name to serialized value (via `_asdict()`).
    # a namedtuple is also a tuple, so this has to precede the next handler
    SerializerHandler(
        check=lambda self, obj, path: isinstance_namedtuple(obj),
        serialize_func=lambda self, obj, path: {
            str(k): self.json_serialize(v, tuple(path) + (k,))
            for k, v in obj._asdict().items()
        },
        uid="namedtuple -> dict",
        desc="namedtuples as dicts",
    ),
    # lists and tuples -> list, with the element index appended to the path
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (list, tuple)),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(list, tuple) -> list",
        desc="lists and tuples as lists",
    ),
)

166 

167 

168def _serialize_override_serialize_func( 

169 self: "JsonSerializer", obj: Any, path: ObjectPath 

170) -> JSONitem: 

171 # obj_cls: type = type(obj) 

172 # if hasattr(obj_cls, "_register_self") and callable(obj_cls._register_self): 

173 # obj_cls._register_self() 

174 

175 # get the serialized object 

176 return obj.serialize() 

177 

178 

# default handler chain used by `JsonSerializer`: `BASE_HANDLERS` first, then the
# handlers below, tried in order. The final "fallback" handler matches anything.
DEFAULT_HANDLERS: MonoTuple[SerializerHandler] = tuple(BASE_HANDLERS) + (
    # objects that know how to serialize themselves
    # NOTE(review): a *class* defining an instance method `serialize` also passes
    # this check, and `obj.serialize()` on the class fails for lack of `self` --
    # consider excluding types here (kept as-is: a classmethod `serialize` works today)
    SerializerHandler(
        # TODO: allow for custom serialization handler name
        check=lambda self, obj, path: (
            hasattr(obj, "serialize") and callable(obj.serialize)
        ),
        serialize_func=_serialize_override_serialize_func,
        uid=".serialize override",
        desc="objects with .serialize method",
    ),
    # dataclass *instances* -> dict of field name to serialized value.
    # `is_dataclass` is also True for dataclass types; those are excluded so a
    # class object is not serialized as if it were an instance (which read field
    # defaults, or raised for fields without one)
    SerializerHandler(
        check=lambda self, obj, path: (
            is_dataclass(obj) and not isinstance(obj, type)
        ),
        serialize_func=lambda self, obj, path: {
            k: self.json_serialize(getattr(obj, k), tuple(path) + (k,))
            for k in obj.__dataclass_fields__
        },
        uid="dataclass -> dict",
        desc="dataclasses as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Path),
        serialize_func=lambda self, obj, path: obj.as_posix(),
        uid="path -> str",
        desc="Path objects as posix strings",
    ),
    # types are matched by their `str(type(obj))` so optional libraries need not be imported
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) in SERIALIZE_DIRECT_AS_STR,
        serialize_func=lambda self, obj, path: str(obj),
        uid="obj -> str(obj)",
        desc="directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'numpy.ndarray'>",
        serialize_func=lambda self, obj, path: cast(
            JSONitem, serialize_array(self, obj, path=path)
        ),
        uid="numpy.ndarray",
        desc="numpy arrays",
    ),
    # tensors are detached and moved to cpu before being handed to `serialize_array`
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'torch.Tensor'>",
        serialize_func=lambda self, obj, path: cast(
            JSONitem,
            serialize_array(
                self,
                obj.detach().cpu(),
                path=path,  # pyright: ignore[reportAny]
            ),
        ),
        uid="torch.Tensor",
        desc="pytorch tensors",
    ),
    # NOTE(review): the records from `to_dict(orient="records")` are not passed
    # through `self.json_serialize`, so non-JSON cell values (e.g. timestamps) are
    # emitted as-is; `path` is also embedded in the output -- confirm both intended
    SerializerHandler(
        check=lambda self, obj, path: (
            str(type(obj)) == "<class 'pandas.core.frame.DataFrame'>"
        ),
        # TYPING: type checkers have no idea that obj is a DataFrame here
        serialize_func=lambda self, obj, path: {  # pyright: ignore[reportArgumentType, reportAny]
            _FORMAT_KEY: "pandas.DataFrame",  # type: ignore[misc]
            "columns": obj.columns.tolist(),  # pyright: ignore[reportAny]
            "data": obj.to_dict(orient="records"),  # pyright: ignore[reportAny]
            "path": path,
        },
        uid="pandas.DataFrame",
        desc="pandas DataFrames",
    ),
    # sets keep their kind in `_FORMAT_KEY`; element order follows set iteration order
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (set, frozenset)),
        serialize_func=lambda self, obj, path: {
            _FORMAT_KEY: "set" if isinstance(obj, set) else "frozenset",  # type: ignore[misc]
            "data": [
                self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
            ],
        },
        uid="set -> dict[_FORMAT_KEY: 'set', data: list(...)]",
        desc="sets as dicts with format key",
    ),
    # any other iterable is consumed into a list
    SerializerHandler(
        check=lambda self, obj, path: (
            isinstance(obj, Iterable) and not isinstance(obj, (list, tuple, str))
        ),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="Iterable -> list",
        desc="Iterables (not lists/tuples/strings) as lists",
    ),
    # catch-all: stringified special attributes plus the results of the special functions
    SerializerHandler(
        check=lambda self, obj, path: True,
        serialize_func=lambda self, obj, path: {
            **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS},  # type: ignore[typeddict-item]
            **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()},
        },
        uid="fallback",
        desc="fallback handler -- serialize object attributes and special functions as strings",
    ),
)

276 

277 

class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
     - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
     - `error_mode : ErrorMode`
        what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
        (defaults to `"except"`)
     - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
     - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
     - `write_only_format : bool`
        changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
     - `ValueError`: on init, if `args` is not empty
     - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

    """

    def __init__(
        self,
        *args: None,
        array_mode: "ArrayMode" = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = (),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        # `*args` exists only to reject positional arguments: all options are keyword-only
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: "ArrayMode" = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers: `handlers_pre` are tried before `handlers_default`
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    @overload
    def json_serialize(
        self, obj: Mapping[str, Any], path: ObjectPath = ()
    ) -> JSONdict: ...
    @overload
    def json_serialize(self, obj: list, path: ObjectPath = ()) -> list: ...
    # @overload # pyright: ignore[reportOverlappingOverload]
    # def json_serialize(self, obj: set, path: ObjectPath = ()) -> _SerializedSet: ...
    # @overload
    # def json_serialize(
    #     self, obj: frozenset, path: ObjectPath = ()
    # ) -> _SerializedFrozenset: ...
    @overload
    def json_serialize(self, obj: Any, path: ObjectPath = ()) -> JSONitem: ...
    def json_serialize(
        self,
        obj: Any,  # pyright: ignore[reportAny]
        path: ObjectPath = (),
    ) -> JSONitem:
        """serialize `obj` with the first handler whose `check` returns `True`

        # Parameters:
         - `obj : Any`
            object to serialize
         - `path : ObjectPath`
            location of `obj` inside the top-level object; passed to handlers and
            used in error messages
            (defaults to `()`)

        # Returns:
         - `JSONitem`
            the matching handler's output. When `write_only_format` is set, a
            `_FORMAT_KEY` entry in a dict output is renamed to `"__write_format__"`
            on a copy, leaving the handler's dict untouched. If an error occurs and
            `error_mode` is not `ErrorMode.EXCEPT`, returns `repr(obj)` instead
            (warning first under `ErrorMode.WARN`).

        # Raises:
         - `SerializationException`
            if serialization fails, including when no handler matches, and
            `error_mode` is `ErrorMode.EXCEPT`
        """
        handler = None
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if (
                        self.write_only_format
                        and isinstance(output, dict)
                        and _FORMAT_KEY in output
                    ):
                        # copy before renaming: the handler may return a dict that
                        # something else still references (e.g. `obj.serialize()`).
                        # pop + reinsert moves the key to the end, as before.
                        output = dict(output)
                        # TYPING: JSONitem has no idea that _FORMAT_KEY is str
                        new_fmt: str = output.pop(_FORMAT_KEY)  # type: ignore # pyright: ignore[reportAssignmentType]
                        output["__write_format__"] = new_fmt  # type: ignore
                    return output

            # every `check` returned False. Clear `handler`, which still holds the
            # last handler tried, so the error reports "no handler matched".
            handler = None
            raise ValueError(f"no handler found for object with {type(obj) = }")  # pyright: ignore[reportAny]

        except Exception as e:
            # NOTE: failures from nested `self.json_serialize` calls made by handlers
            # are (presumably, if SerializationException subclasses Exception) caught
            # here too and re-wrapped, one layer per enclosing path
            if self.error_mode == ErrorMode.EXCEPT:
                obj_str: str = repr(obj)  # pyright: ignore[reportAny]
                # keep the message readable for very large objects
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                handler_uid = (
                    handler.uid if handler is not None else "no handler matched"
                )
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == ErrorMode.WARN:
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            # any non-EXCEPT error mode falls back to the object's repr
            return repr(obj)  # pyright: ignore[reportAny]

    def hashify(
        self,
        obj: Any,  # pyright: ignore[reportAny]
        path: ObjectPath = (),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable

        serializes `obj` with `json_serialize`, then passes the result and `force`
        to `_recursive_hashify`, which turns dicts and lists into tuples
        """
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)

386 

387 

# serializer with the default configuration, shared by the module-level `json_serialize`
GLOBAL_JSON_SERIALIZER: JsonSerializer = JsonSerializer()


@overload
def json_serialize(obj: Mapping[str, Any], path: ObjectPath = ()) -> JSONdict: ...
@overload
def json_serialize(obj: list, path: ObjectPath = ()) -> list: ...
# the overloads below stay disabled; their decorator is commented out too, matching
# `JsonSerializer.json_serialize` (a live `@overload` here double-decorated the next one)
# @overload # pyright: ignore[reportOverlappingOverload]
# def json_serialize(obj: set, path: ObjectPath = ()) -> _SerializedSet: ...
# @overload
# def json_serialize(obj: frozenset, path: ObjectPath = ()) -> _SerializedFrozenset: ...
@overload
def json_serialize(obj: Any, path: ObjectPath = ()) -> JSONitem: ...
def json_serialize(obj: Any, path: ObjectPath = ()) -> JSONitem:  # pyright: ignore[reportAny]
    """serialize object to json-serializable object with default config

    thin wrapper around `GLOBAL_JSON_SERIALIZER.json_serialize`

    # Parameters:
     - `obj : Any`
        object to serialize
     - `path : ObjectPath`
        location of `obj` inside an enclosing object, used in error messages
        (defaults to `()`)

    # Returns:
     - `JSONitem` the serialized object
    """
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)