Coverage for muutils/json_serialize/json_serialize.py: 28%
64 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""provides the basic framework for json serialization of objects
3notably:
5- `SerializerHandler` defines how to serialize a specific type of object
6- `JsonSerializer` handles configuration for which handlers to use
7- `json_serialize` provides the default configuration if you don't care -- call it on any object!
9"""
11from __future__ import annotations
13import inspect
14import warnings
15from dataclasses import dataclass, is_dataclass
16from pathlib import Path
17from typing import Any, Callable, Iterable, Mapping, Set, Union
19from muutils.errormode import ErrorMode
21try:
22 from muutils.json_serialize.array import ArrayMode, serialize_array
23except ImportError as e:
24 ArrayMode = str # type: ignore[misc]
25 serialize_array = lambda *args, **kwargs: None # noqa: E731
26 warnings.warn(
27 f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}",
28 ImportWarning,
29 )
31from muutils.json_serialize.util import (
32 _FORMAT_KEY,
33 Hashableitem,
34 JSONitem,
35 MonoTuple,
36 SerializationException,
37 _recursive_hashify,
38 isinstance_namedtuple,
39 safe_getsource,
40 string_as_lines,
41 try_catch,
42)
44# pylint: disable=protected-access
# attribute names that the "fallback" handler reads off an object with
# `getattr(obj, key, None)` and stores as strings
SERIALIZER_SPECIAL_KEYS: MonoTuple[str] = (
    "__name__",
    "__doc__",
    "__module__",
    "__class__",
    "__dict__",
    "__annotations__",
)

# output key -> function applied to the object by the "fallback" handler
SERIALIZER_SPECIAL_FUNCS: dict[str, Callable] = dict(
    str=str,
    dir=dir,
    type=try_catch(lambda x: str(type(x).__name__)),
    repr=try_catch(lambda x: repr(x)),
    code=try_catch(lambda x: inspect.getsource(x)),
    sourcefile=try_catch(lambda x: inspect.getsourcefile(x)),
)

# values of `str(type(obj))` for which the object is emitted as `str(obj)`;
# matching on the string means torch does not have to be imported here
SERIALIZE_DIRECT_AS_STR: Set[str] = {
    "<class 'torch.device'>",
    "<class 'torch.dtype'>",
}

# location of an object inside the structure being serialized:
# a tuple of mapping keys / attribute names (str) and sequence indices (int)
ObjectPath = MonoTuple[Union[str, int]]
@dataclass
class SerializerHandler:
    """a handler for a specific type of object

    # Parameters:
     - `check : Callable[[JsonSerializer, Any, ObjectPath], bool]`
       takes a JsonSerializer, an object, and the current path, returns whether to use this handler
     - `serialize_func : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]`
       takes a JsonSerializer, an object, and the current path, returns the serialized object
     - `uid : str`
       unique identifier for the handler
     - `desc : str`
       description of the handler
    """

    # (serializer, object, path) -> whether to use this handler
    check: Callable[["JsonSerializer", Any, ObjectPath], bool]
    # (serializer, object, path) -> serialized object
    serialize_func: Callable[["JsonSerializer", Any, ObjectPath], JSONitem]
    # unique identifier for the handler
    uid: str
    # description of this serializer
    desc: str

    def serialize(self) -> dict:
        """serialize the handler info (not an object!)

        # Returns:
         - `dict` with keys:
           - `"check"`, `"serialize_func"`: source code and docstring lines of the respective callable
           - `"uid"`, `"desc"`: string forms of the respective fields
           - `"source_pckg"`, `"__module__"`: attributes of `serialize_func`, `None` when absent
        """
        # source and doc of the check function
        check_info: dict = {
            "code": safe_getsource(self.check),
            "doc": string_as_lines(self.check.__doc__),
        }

        # source and doc of the serialization function
        func = self.serialize_func
        func_info: dict = {
            "code": safe_getsource(func),
            "doc": string_as_lines(func.__doc__),
        }

        return {
            "check": check_info,
            "serialize_func": func_info,
            "uid": str(self.uid),
            "source_pckg": getattr(func, "source_pckg", None),
            "__module__": getattr(func, "__module__", None),
            "desc": str(self.desc),
        }
# handlers for plain JSON-compatible values and containers;
# `JsonSerializer.json_serialize` uses the first handler whose `check` passes
BASE_HANDLERS: MonoTuple[SerializerHandler] = (
    # primitives are already valid JSON values -- pass them through untouched
    SerializerHandler(
        check=lambda self, obj, path: (
            obj is None or isinstance(obj, (bool, int, float, str))
        ),
        serialize_func=lambda self, obj, path: obj,
        uid="base types",
        desc="base types (bool, int, float, str, None)",
    ),
    # mappings: keys are stringified, values are serialized recursively with
    # the (original, un-stringified) key appended to the path
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Mapping),
        serialize_func=lambda self, obj, path: {
            str(key): self.json_serialize(value, (*path, key))
            for key, value in obj.items()
        },
        uid="dictionaries",
        desc="dictionaries",
    ),
    # lists and tuples: elements are serialized recursively with their index
    # appended to the path
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (list, tuple)),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(item, (*path, idx)) for idx, item in enumerate(obj)
        ],
        uid="(list, tuple) -> list",
        desc="lists and tuples as lists",
    ),
)
140def _serialize_override_serialize_func(
141 self: "JsonSerializer", obj: Any, path: ObjectPath
142) -> JSONitem:
143 # obj_cls: type = type(obj)
144 # if hasattr(obj_cls, "_register_self") and callable(obj_cls._register_self):
145 # obj_cls._register_self()
147 # get the serialized object
148 return obj.serialize()
# default handler chain: `BASE_HANDLERS` first, then the handlers below.
# order matters -- `JsonSerializer.json_serialize` uses the first handler whose
# `check` returns True, so the catch-all "fallback" handler must stay last.
DEFAULT_HANDLERS: MonoTuple[SerializerHandler] = tuple(BASE_HANDLERS) + (
    SerializerHandler(
        # TODO: allow for custom serialization handler name
        check=lambda self, obj, path: hasattr(obj, "serialize")
        and callable(obj.serialize),
        serialize_func=_serialize_override_serialize_func,
        uid=".serialize override",
        desc="objects with .serialize method",
    ),
    # NOTE(review): namedtuples are tuple subclasses, so the "(list, tuple) -> list"
    # handler in `BASE_HANDLERS` appears to match them before this handler is
    # reached -- confirm whether this one is meant to take precedence
    SerializerHandler(
        check=lambda self, obj, path: isinstance_namedtuple(obj),
        # forward `path` so that errors raised below a namedtuple still report
        # where in the structure they happened
        serialize_func=lambda self, obj, path: self.json_serialize(
            dict(obj._asdict()), path
        ),
        uid="namedtuple -> dict",
        desc="namedtuples as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: is_dataclass(obj),
        serialize_func=lambda self, obj, path: {
            k: self.json_serialize(getattr(obj, k), tuple(path) + (k,))
            for k in obj.__dataclass_fields__
        },
        uid="dataclass -> dict",
        desc="dataclasses as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Path),
        serialize_func=lambda self, obj, path: obj.as_posix(),
        uid="path -> str",
        desc="Path objects as posix strings",
    ),
    # the handlers below match on `str(type(obj))` rather than `isinstance`,
    # so torch / numpy / pandas never have to be imported by this module
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) in SERIALIZE_DIRECT_AS_STR,
        serialize_func=lambda self, obj, path: str(obj),
        uid="obj -> str(obj)",
        desc="directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'numpy.ndarray'>",
        serialize_func=lambda self, obj, path: serialize_array(self, obj, path=path),
        uid="numpy.ndarray",
        desc="numpy arrays",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'torch.Tensor'>",
        # detach from the autograd graph and move to cpu before handing the
        # tensor to the array serializer
        serialize_func=lambda self, obj, path: serialize_array(
            self, obj.detach().cpu(), path=path
        ),
        uid="torch.Tensor",
        desc="pytorch tensors",
    ),
    SerializerHandler(
        check=lambda self, obj, path: (
            str(type(obj)) == "<class 'pandas.core.frame.DataFrame'>"
        ),
        serialize_func=lambda self, obj, path: {
            _FORMAT_KEY: "pandas.DataFrame",
            "columns": obj.columns.tolist(),
            "data": obj.to_dict(orient="records"),
            "path": path,  # type: ignore
        },
        uid="pandas.DataFrame",
        desc="pandas DataFrames",
    ),
    SerializerHandler(
        # set, list and tuple are all `Iterable`, so a single check covers them
        check=lambda self, obj, path: isinstance(obj, Iterable),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(set, list, tuple, Iterable) -> list",
        desc="sets, lists, tuples, and Iterables as lists",
    ),
    # catch-all: always matches, describes the object via
    # `SERIALIZER_SPECIAL_KEYS` (stringified attributes) and
    # `SERIALIZER_SPECIAL_FUNCS` (functions applied to the object)
    SerializerHandler(
        check=lambda self, obj, path: True,
        serialize_func=lambda self, obj, path: {
            **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS},
            **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()},
        },
        uid="fallback",
        desc="fallback handler -- serialize object attributes and special functions as strings",
    ),
)
class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
     - `array_mode : ArrayMode`
       how to write arrays
       (defaults to `"array_list_meta"`)
     - `error_mode : ErrorMode`
       what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
       (defaults to `"except"`)
     - `handlers_pre : MonoTuple[SerializerHandler]`
       handlers to use before the default handlers
       (defaults to `tuple()`)
     - `handlers_default : MonoTuple[SerializerHandler]`
       default handlers to use
       (defaults to `DEFAULT_HANDLERS`)
     - `write_only_format : bool`
       changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
       (defaults to `False`)

    # Raises:
     - `ValueError`: on init, if `args` is not empty
     - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        # `*args` exists only to reject positional use with a clear message
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers: user-provided ones are tried first
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        """serialize `obj` using the first handler in `self.handlers` whose `check` passes

        # Parameters:
         - `obj : Any`
           object to serialize
         - `path : ObjectPath`
           location of `obj` inside the structure being serialized, passed on to the handlers and used in error messages
           (defaults to `tuple()`)

        # Returns:
         - `JSONitem`
           output of the matching handler, or `repr(obj)` if serialization failed and `error_mode` is not "except"

        # Raises:
         - `SerializationException`: if anything goes wrong (including no handler matching) and `error_mode` is "except"
        """
        # bound before the loop so the error path below can always refer to it,
        # even when `self.handlers` is empty
        handler = None
        try:
            for handler in self.handlers:
                if not handler.check(self, obj, path):
                    continue

                output: JSONitem = handler.serialize_func(self, obj, path)
                if (
                    self.write_only_format
                    and isinstance(output, dict)
                    and _FORMAT_KEY in output
                ):
                    # rename the format key (in place) so loaders will not try
                    # to reconstruct the original object
                    output["__write_format__"] = output.pop(_FORMAT_KEY)
                return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            # compare on the underlying value: `self.error_mode` is normally an
            # `ErrorMode` member, and a plain (non-str) Enum member never
            # compares equal to a raw string. falling back to the object itself
            # keeps raw-string modes working too.
            mode = getattr(self.error_mode, "value", self.error_mode)
            if mode == "except":
                obj_str: str = repr(obj)
                # keep the error message readable for huge objects
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                last_uid = handler.uid if handler is not None else None
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{last_uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            # any mode other than "except": fall back to the repr
            return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
# module-level serializer built with the default configuration;
# `json_serialize()` below delegates to it
GLOBAL_JSON_SERIALIZER: JsonSerializer = JsonSerializer()


def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config

    thin wrapper around `GLOBAL_JSON_SERIALIZER.json_serialize`
    """
    serializer = GLOBAL_JSON_SERIALIZER
    return serializer.json_serialize(obj, path=path)