Coverage for muutils / json_serialize / json_serialize.py: 51%
67 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
1"""provides the basic framework for json serialization of objects
3notably:
5- `SerializerHandler` defines how to serialize a specific type of object
6- `JsonSerializer` handles configuration for which handlers to use
7- `json_serialize` provides the default configuration if you don't care -- call it on any object!
9"""
11from __future__ import annotations
13import inspect
14import warnings
15from dataclasses import dataclass, is_dataclass
16from pathlib import Path
17from typing import (
18 TYPE_CHECKING,
19 Any,
20 Callable,
21 Iterable,
22 Mapping,
23 Set,
24 Union,
25 cast,
26 overload,
27)
29from muutils.errormode import ErrorMode
31if TYPE_CHECKING:
32 # always need array.py for type checking
33 from muutils.json_serialize.array import ArrayMode, serialize_array
34else:
35 try:
36 from muutils.json_serialize.array import ArrayMode, serialize_array
37 except ImportError as e:
38 # TYPING: obviously, these types are all wrong if we can't import array.py
39 ArrayMode = str # type: ignore[misc]
40 serialize_array = lambda *args, **kwargs: None # type: ignore[assignment, invalid-assignment] # noqa: E731
41 warnings.warn(
42 f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}",
43 ImportWarning,
44 )
46from muutils.json_serialize.types import (
47 _FORMAT_KEY,
48 Hashableitem,
49) # pyright: ignore[reportPrivateUsage]
51from muutils.json_serialize.util import (
52 JSONdict,
53 JSONitem,
54 MonoTuple,
55 SerializationException,
56 _recursive_hashify, # pyright: ignore[reportPrivateUsage, reportUnknownVariableType]
57 isinstance_namedtuple,
58 safe_getsource,
59 string_as_lines,
60 try_catch,
61)
# pylint: disable=protected-access

# attribute names the fallback handler looks up with `getattr(obj, name, None)`
# and records as strings (missing attributes become "None")
SERIALIZER_SPECIAL_KEYS: MonoTuple[str] = (
    "__name__",
    "__doc__",
    "__module__",
    "__class__",
    "__dict__",
    "__annotations__",
)

# output key -> function applied to the object by the fallback handler.
# `str` and `dir` are called directly; every other entry is wrapped in `try_catch`
SERIALIZER_SPECIAL_FUNCS: dict[str, Callable[..., str | list[str]]] = {
    "str": str,
    "dir": dir,
    "type": try_catch(lambda target: str(type(target).__name__)),  # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType]
    "repr": try_catch(lambda target: repr(target)),  # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType]
    "code": try_catch(lambda target: inspect.getsource(target)),  # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType]
    "sourcefile": try_catch(lambda target: str(inspect.getsourcefile(target))),  # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType]
}

# `str(type(obj))` values whose objects are serialized simply as `str(obj)`;
# matching on the type string means the library itself need not be imported
SERIALIZE_DIRECT_AS_STR: Set[str] = {"<class 'torch.device'>", "<class 'torch.dtype'>"}

# location of a value inside the top-level object being serialized: a tuple of
# keys / field names and sequence indices, extended one element per nesting level
ObjectPath = MonoTuple[Union[str, int]]
@dataclass
class SerializerHandler:
    """a handler for a specific type of object

    `JsonSerializer.json_serialize` walks its handlers in order and uses the
    first one whose `check` returns `True`.

    # Parameters:
    - `check : Callable[[JsonSerializer, Any, ObjectPath], bool]`
        takes a `JsonSerializer`, an object, and the current path;
        returns whether to use this handler
    - `serialize_func : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]`
        takes a `JsonSerializer`, an object, and the current path;
        returns the serialized object
    - `uid : str`
        unique identifier for the handler (reported in serialization errors)
    - `desc : str`
        human-readable description of the handler (required, no default)
    """

    # (self_config, object, path) -> whether to use this handler
    check: Callable[["JsonSerializer", Any, ObjectPath], bool]
    # (self_config, object, path) -> serialized object
    serialize_func: Callable[["JsonSerializer", Any, ObjectPath], JSONitem]
    # unique identifier for the handler
    uid: str
    # description of this serializer
    desc: str

    def serialize(self) -> JSONdict:
        """serialize the handler's own metadata (not an object it handles)

        # Returns:
        - `JSONdict` with:
            - `"check"` / `"serialize_func"`: source code (via `safe_getsource`)
              and docstring lines (via `string_as_lines`) of each function
            - `"uid"`, `"desc"`: as strings
            - `"source_pckg"`, `"__module__"`: the corresponding attributes of
              `serialize_func`, or `None` if it has no such attribute
        """
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the serialize function
            "serialize_func": {
                "code": safe_getsource(self.serialize_func),
                "doc": string_as_lines(self.serialize_func.__doc__),
            },
            # get the uid, source_pckg, module, and desc
            "uid": str(self.uid),
            "source_pckg": getattr(self.serialize_func, "source_pckg", None),
            "__module__": getattr(self.serialize_func, "__module__", None),
            "desc": str(self.desc),
        }
# handlers for JSON scalars and the basic containers.
# `JsonSerializer.json_serialize` uses the FIRST handler whose `check` returns
# True, so order matters here: the namedtuple handler must come before the
# generic (list, tuple) handler, because a namedtuple is also a tuple.
# container handlers recurse via `self.json_serialize`, extending the path with
# the key / index of each child.
BASE_HANDLERS: MonoTuple[SerializerHandler] = (
    # scalars are already JSON-compatible and are returned unchanged
    SerializerHandler(
        check=lambda self, obj, path: isinstance(
            obj, (bool, int, float, str, type(None))
        ),
        serialize_func=lambda self, obj, path: obj,
        uid="base types",
        desc="base types (bool, int, float, str, None)",
    ),
    # any Mapping -> dict; output keys are stringified with `str(k)`, while the
    # path records the original key `k`
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Mapping),
        serialize_func=lambda self, obj, path: {
            str(k): self.json_serialize(v, tuple(path) + (k,)) for k, v in obj.items()
        },
        uid="dictionaries",
        desc="dictionaries",
    ),
    # namedtuple -> dict of its fields (via `_asdict()`), instead of a plain list
    SerializerHandler(
        check=lambda self, obj, path: isinstance_namedtuple(obj),
        serialize_func=lambda self, obj, path: {
            str(k): self.json_serialize(v, tuple(path) + (k,))
            for k, v in obj._asdict().items()
        },
        uid="namedtuple -> dict",
        desc="namedtuples as dicts",
    ),
    # lists and (non-named) tuples -> list; the path records the element index
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (list, tuple)),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(list, tuple) -> list",
        desc="lists and tuples as lists",
    ),
)
def _serialize_override_serialize_func(
    self: "JsonSerializer", obj: Any, path: ObjectPath
) -> JSONitem:
    """serialize `obj` by delegating to its own zero-argument `.serialize()` method

    used as the `serialize_func` of the ".serialize override" handler

    # Parameters:
    - `self : JsonSerializer`
        serializer config; unused, present to match the handler signature
    - `obj : Any`
        object with a callable `serialize` attribute
    - `path : ObjectPath`
        current path; unused, present to match the handler signature

    # Returns:
    - `JSONitem`
        whatever `obj.serialize()` returns, passed through unchanged
    """
    # NOTE: a hook calling `type(obj)._register_self()` (when defined) before
    # serializing was considered here, but is currently disabled
    return obj.serialize()
# handlers used by `JsonSerializer` unless overridden: the base handlers, then
# handlers for common non-JSON types, ending with a catch-all fallback.
# `JsonSerializer.json_serialize` uses the first handler whose `check` passes,
# so order matters (e.g. `.serialize` overrides win over the dataclass handler,
# and sets are matched before the generic Iterable handler).
DEFAULT_HANDLERS: MonoTuple[SerializerHandler] = tuple(BASE_HANDLERS) + (
    SerializerHandler(
        # TODO: allow for custom serialization handler name
        # class objects are excluded: on a class, `obj.serialize` is normally an
        # unbound instance method, and calling it with no arguments raises
        # TypeError. (this also means a classmethod/staticmethod `serialize` on a
        # class object is not used; such classes go to the later handlers.)
        check=lambda self, obj, path: (
            not isinstance(obj, type)
            and hasattr(obj, "serialize")
            and callable(obj.serialize)
        ),
        serialize_func=_serialize_override_serialize_func,
        uid=".serialize override",
        desc="objects with .serialize method",
    ),
    SerializerHandler(
        # `is_dataclass` is True for dataclass classes as well as instances; only
        # instances are handled here, because reading fields off the class gives
        # class-level defaults or raises AttributeError for fields without one
        check=lambda self, obj, path: (
            is_dataclass(obj) and not isinstance(obj, type)
        ),
        serialize_func=lambda self, obj, path: {
            k: self.json_serialize(getattr(obj, k), tuple(path) + (k,))
            for k in obj.__dataclass_fields__
        },
        uid="dataclass -> dict",
        desc="dataclasses as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Path),
        serialize_func=lambda self, obj, path: obj.as_posix(),
        uid="path -> str",
        desc="Path objects as posix strings",
    ),
    # the torch / numpy / pandas handlers below compare `str(type(obj))`, so
    # those libraries are never imported by these checks
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) in SERIALIZE_DIRECT_AS_STR,
        serialize_func=lambda self, obj, path: str(obj),
        uid="obj -> str(obj)",
        desc="directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'numpy.ndarray'>",
        serialize_func=lambda self, obj, path: cast(
            JSONitem, serialize_array(self, obj, path=path)
        ),
        uid="numpy.ndarray",
        desc="numpy arrays",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'torch.Tensor'>",
        # detach from the autograd graph and move to CPU before array serialization
        serialize_func=lambda self, obj, path: cast(
            JSONitem,
            serialize_array(
                self,
                obj.detach().cpu(),  # pyright: ignore[reportAny]
                path=path,
            ),
        ),
        uid="torch.Tensor",
        desc="pytorch tensors",
    ),
    SerializerHandler(
        check=lambda self, obj, path: (
            str(type(obj)) == "<class 'pandas.core.frame.DataFrame'>"
        ),
        # TYPING: type checkers have no idea that obj is a DataFrame here
        serialize_func=lambda self, obj, path: {  # pyright: ignore[reportArgumentType, reportAny]
            _FORMAT_KEY: "pandas.DataFrame",  # type: ignore[misc]
            "columns": obj.columns.tolist(),  # pyright: ignore[reportAny]
            "data": obj.to_dict(orient="records"),  # pyright: ignore[reportAny]
            "path": path,
        },
        uid="pandas.DataFrame",
        desc="pandas DataFrames",
    ),
    SerializerHandler(
        # element order follows set iteration order
        check=lambda self, obj, path: isinstance(obj, (set, frozenset)),
        serialize_func=lambda self, obj, path: {
            _FORMAT_KEY: "set" if isinstance(obj, set) else "frozenset",  # type: ignore[misc]
            "data": [
                self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
            ],
        },
        uid="set -> dict[_FORMAT_KEY: 'set', data: list(...)]",
        desc="sets as dicts with format key",
    ),
    SerializerHandler(
        # note: this consumes one-shot iterators such as generators
        check=lambda self, obj, path: (
            isinstance(obj, Iterable) and not isinstance(obj, (list, tuple, str))
        ),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="Iterable -> list",
        desc="Iterables (not lists/tuples/strings) as lists",
    ),
    SerializerHandler(
        # always matches, so it must stay last
        check=lambda self, obj, path: True,
        serialize_func=lambda self, obj, path: {
            **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS},  # type: ignore[typeddict-item]
            **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()},
        },
        uid="fallback",
        desc="fallback handler -- serialize object attributes and special functions as strings",
    ),
)
def _safe_repr(obj: Any, max_len: int | None = None) -> str:
    """return `repr(obj)`, never raising, optionally truncated

    # Parameters:
    - `obj : Any`
        object to represent
    - `max_len : int | None`
        if given, reprs longer than this are cut to `max_len` characters
        followed by `"..."` (defaults to `None`, meaning no truncation)

    # Returns:
    - `str`
        the repr, or a placeholder naming the type if `repr(obj)` itself raised
    """
    try:
        obj_repr: str = repr(obj)
    except Exception as e:  # noqa: BLE001
        # a broken `__repr__` must not replace the error being reported
        obj_repr = f"<unrepresentable {type(obj).__name__} object: {e!r}>"
    if max_len is not None and len(obj_repr) > max_len:
        obj_repr = obj_repr[:max_len] + "..."
    return obj_repr


class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
        what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
        (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
        changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
    """

    def __init__(
        self,
        *args: None,
        array_mode: "ArrayMode" = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = (),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        # all configuration is keyword-only; `*args` exists only to reject positionals
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: "ArrayMode" = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers: user handlers are tried before the defaults
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    @overload
    def json_serialize(
        self, obj: Mapping[str, Any], path: ObjectPath = ()
    ) -> JSONdict: ...
    @overload
    def json_serialize(self, obj: list, path: ObjectPath = ()) -> list: ...
    # @overload # pyright: ignore[reportOverlappingOverload]
    # def json_serialize(self, obj: set, path: ObjectPath = ()) -> _SerializedSet: ...
    # @overload
    # def json_serialize(
    #     self, obj: frozenset, path: ObjectPath = ()
    # ) -> _SerializedFrozenset: ...
    @overload
    def json_serialize(self, obj: Any, path: ObjectPath = ()) -> JSONitem: ...
    def json_serialize(
        self,
        obj: Any,  # pyright: ignore[reportAny]
        path: ObjectPath = (),
    ) -> JSONitem:
        """serialize `obj` with the first handler in `self.handlers` whose `check` passes

        # Parameters:
        - `obj : Any`
            object to serialize
        - `path : ObjectPath`
            location of `obj` inside the top-level object; container handlers
            extend it for children, and it is reported in error messages
            (defaults to `()`)

        # Returns:
        - `JSONitem`
            the handler's output. If `self.write_only_format` is set and the
            output is a dict containing `_FORMAT_KEY`, a new dict is returned in
            which that key is replaced by `"__write_format__"` (placed last)

        # Raises:
        - `SerializationException`
            if serialization fails and `self.error_mode` is `ErrorMode.EXCEPT`.
            With `ErrorMode.WARN` a warning is emitted and `repr(obj)` is
            returned; any other error mode returns `repr(obj)` silently.
        """
        handler: SerializerHandler | None = None
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if (
                        self.write_only_format
                        and isinstance(output, dict)
                        and _FORMAT_KEY in output
                    ):
                        # rename the format key in a NEW dict: popping in place would
                        # mutate dicts owned elsewhere, e.g. a cached dict returned
                        # by a custom `.serialize()` method
                        # TYPING: JSONitem has no idea that _FORMAT_KEY is str
                        new_fmt: str = output[_FORMAT_KEY]  # type: ignore # pyright: ignore[reportAssignmentType]
                        output = {k: v for k, v in output.items() if k != _FORMAT_KEY}  # type: ignore
                        output["__write_format__"] = new_fmt  # type: ignore
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")  # pyright: ignore[reportAny]

        except Exception as e:
            # `handler` is the last handler tried (None only if there are no handlers)
            if self.error_mode == ErrorMode.EXCEPT:
                obj_str: str = _safe_repr(obj, max_len=1000)
                handler_uid = handler.uid if handler is not None else "no handler matched"
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == ErrorMode.WARN:
                # truncated like the EXCEPT message, so huge objects don't flood the warning
                warn_obj_str: str = _safe_repr(obj, max_len=1000)
                warnings.warn(
                    f"error serializing at {path = }, will return as string\nobj = {warn_obj_str}\nexception = {e}"
                )

        # non-EXCEPT modes fall back to the (untruncated) repr of the object
        return _safe_repr(obj)  # pyright: ignore[reportAny]

    def hashify(
        self,
        obj: Any,  # pyright: ignore[reportAny]
        path: ObjectPath = (),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable

        serializes `obj` with `self.json_serialize`, then converts the result
        with `_recursive_hashify` (dicts and lists become tuples); `force` is
        passed through to `_recursive_hashify`
        """
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
# module-wide serializer with the default configuration, used by `json_serialize`
GLOBAL_JSON_SERIALIZER: JsonSerializer = JsonSerializer()


@overload
def json_serialize(obj: Mapping[str, Any], path: ObjectPath = ()) -> JSONdict: ...
@overload
def json_serialize(obj: list, path: ObjectPath = ()) -> list: ...
# @overload # pyright: ignore[reportOverlappingOverload]
# def json_serialize(obj: set, path: ObjectPath = ()) -> _SerializedSet: ...
# @overload
# def json_serialize(obj: frozenset, path: ObjectPath = ()) -> _SerializedFrozenset: ...
@overload
def json_serialize(obj: Any, path: ObjectPath = ()) -> JSONitem: ...
def json_serialize(obj: Any, path: ObjectPath = ()) -> JSONitem:  # pyright: ignore[reportAny]
    """serialize object to json-serializable object with default config

    thin wrapper around `GLOBAL_JSON_SERIALIZER.json_serialize`; construct a
    `JsonSerializer` directly to customize handlers, array mode, or error mode.

    # Parameters:
    - `obj : Any`
        object to serialize
    - `path : ObjectPath`
        location of `obj` inside an enclosing object, passed to handlers and
        reported in error messages (defaults to `()`)

    # Returns:
    - `JSONitem`
        the serialized object

    # Raises:
    - `SerializationException`
        with the default configuration (`ErrorMode.EXCEPT`), if serialization fails
    """
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)