Coverage for muutils / json_serialize / serializable_dataclass.py: 56%
261 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
"""save and load objects to and from json or compatible formats in a recoverable way

`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
you will get an error when you call `json.dumps(d)`. This module provides a way around that.

Instead, you define your class:

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

This isn't too impressive on its own, but it gets more useful when you have nested classes,
or fields that are not json-serializable by default:

```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```

which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True

"""
52from __future__ import annotations
54import abc
55import dataclasses
56import functools
57import json
58import sys
59import typing
60import warnings
61from typing import Any, Optional, Type, TypeVar, overload, TYPE_CHECKING
63from muutils.errormode import ErrorMode
64from muutils.validate_type import validate_type
65from muutils.json_serialize.serializable_field import (
66 SerializableField,
67 serializable_field,
68)
69from muutils.json_serialize.types import _FORMAT_KEY
70from muutils.json_serialize.util import (
71 JSONdict,
72 array_safe_eq,
73 dc_eq,
74)
76# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
78# For type checkers: always use typing_extensions which they can resolve
79# At runtime: use stdlib if available (3.11+), else typing_extensions, else mock
80if TYPE_CHECKING:
81 from typing_extensions import dataclass_transform, Self
82else:
83 if sys.version_info >= (3, 11):
84 from typing import dataclass_transform, Self
85 else:
86 try:
87 from typing_extensions import dataclass_transform, Self
88 except Exception:
89 from muutils.json_serialize.dataclass_transform_mock import (
90 dataclass_transform,
91 )
93 Self = TypeVar("Self")
# type variable for annotating functions that take or return (subclasses of) `SerializableDataclass`
T_SerializeableDataclass = TypeVar("T_SerializeableDataclass", bound="SerializableDataclass")
class CantGetTypeHintsWarning(UserWarning):
    """special warning for when we can't get type hints"""
class ZanjMissingWarning(UserWarning):
    """special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing

    without ZANJ, `register_loader_serializable_dataclass` will not work
    """
# flag to keep track of if we have successfully imported ZANJ
# (read by `zanj_register_loader_serializable_dataclass`, which only attempts the import while it is truthy)
_zanj_loading_needs_import: bool = True


def zanj_register_loader_serializable_dataclass(
    cls: typing.Type[T_SerializeableDataclass],
):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : Type[SerializableDataclass]`
       class to register; its `load` classmethod is used as the loader

    # Returns:
     - the registered `LoaderHandler`, or `None` if `zanj` could not be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import] # pyright: ignore[reportMissingImports]
                LoaderHandler,  # pyright: ignore[reportUnknownVariableType]
                register_loader_handler,  # pyright: ignore[reportUnknownVariableType]
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter,
            # so we bail out quietly instead of emitting a `ZanjMissingWarning`
            return

    # serialized items carry this string (as a prefix) under `_FORMAT_KEY`
    _format: str = f"{cls.__name__}(SerializableDataclass)"

    handler_kwargs: dict = dict(
        # an item is ours if it is a dict whose format key starts with `_format`
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        # `path` and `z` are accepted for the handler interface but not used
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )
    lh: LoaderHandler = LoaderHandler(**handler_kwargs)  # pyright: ignore[reportPossiblyUnboundVariable]

    register_loader_handler(lh)  # pyright: ignore[reportPossiblyUnboundVariable]

    return lh
# default `ErrorMode` for `on_typecheck_mismatch` in `serializable_dataclass`
# (what to do when a field value does not match its type hint)
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
# default `ErrorMode` for `on_typecheck_error` throughout this module
# (what to do when the type check itself raises an exception)
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
class FieldIsNotInitOrSerializeWarning(UserWarning):
    """warning emitted when type validation is skipped for a field that is not `init` or not `serialize`"""
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value of a single field on `self` matches its type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
       field to validate, looked up in `self.__dataclass_fields__` if given as a `str`
     - `on_typecheck_error : ErrorMode`
       what to do if type checking throws an exception (except, warn, ignore).
       If the mode does not raise, the function returns `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       `True` if the field type is correct, or if the field is skipped
       (`assert_type` is falsy, or the field is not `init`/`serialize`).
       `False` if the type is incorrect, or if an exception occurred and was not re-raised
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve the field object; NOTE: the name `_field` shows up verbatim in the messages below
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # validation disabled for this field
    if not _field.assert_type:
        return True

    # fields that are not both `init` and `serialize` are skipped, but we warn about it
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(_field, SerializableField), (
        f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"
    )

    # look up the type hint for this field
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        hint_error_msg: str = (
            f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
            f"{get_cls_type_hints(self.__class__) = }\n"
            f"Python version is {sys.version_info = }. You can:\n"
            f" - disable `assert_type`. Currently: {_field.assert_type = }\n"
            f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
            " - use python 3.9.x or higher\n"
            " - specify custom type validation function via `custom_typecheck_fn`\n"
        )
        on_typecheck_error.process(
            hint_error_msg,
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # current value of the field
    value: Any = getattr(self, _field.name)

    try:
        if _field.custom_typecheck_fn is None:
            # default validator: compare the value against the hint
            return validate_type(value, field_type_hint)
        # custom validator: note that it is called with the type hint only
        return _field.custom_typecheck_fn(field_type_hint)
    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`

    calls `self.validate_field_type` once per dataclass field.
    Exceptions raised by individual fields are collected (the field is recorded as `False`)
    and reported together through `on_typecheck_error` after all fields have been visited.

    # Returns:
     - `dict[str, bool]`
       maps each field name to whether its type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    results: dict[str, bool] = {}
    exceptions: dict[str, Exception] = {}

    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        field_ok: bool
        try:
            field_ok = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            # remember the failure, keep going so every field gets checked
            exceptions[field.name] = e
            field_ok = False
        results[field.name] = field_ok

    if exceptions:
        per_field_lines: list[str] = [f"{k}:\t{v}" for k, v in exceptions.items()]
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join(per_field_lines),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception
            except_from=next(iter(exceptions.values())),
        )

    return results
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`

    thin wrapper around `SerializableDataclass__validate_fields_types__dict`:
    returns `True` only if every field was reported as valid
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator:
    `serialize` and `load` raise `NotImplementedError` here until the decorator replaces them

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        not_decorated_msg: str = (
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )
        raise NotImplementedError(not_decorated_msg)

    @overload
    @classmethod
    def load(cls, data: dict[str, Any]) -> Self: ...

    @overload
    @classmethod
    def load(cls, data: Self) -> Self: ...

    @classmethod
    def load(cls, data: dict[str, Any] | Self) -> Self:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        not_decorated_msg: str = f"decorate {cls = } with `@serializable_dataclass`"
        raise NotImplementedError(not_decorated_msg)

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields, delegating to `SerializableDataclass__validate_fields_types`"""
        all_valid: bool = SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
        return all_valid

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """check a single field against its type hint, delegating to `SerializableDataclass__validate_field_type`"""
        is_valid: bool = SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
        return is_valid

    def __eq__(self, other: Any) -> bool:
        "equality is delegated to `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        as_json: str = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (the reported `'self'`/`'other'` entries are still the raw attribute values)
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if the instances compare equal; otherwise maps each differing field name to
           either a nested diff (for `SerializableDataclass` values) or `{"self": ..., "other": ...}`

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # both sides must be exactly the same class
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        diff_result: dict = {}

        # shortcut: equal instances have an empty diff.
        # if the comparison itself fails, fall through to the field-by-field comparison
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # serialized views are only built when they will be compared
        if of_serialized:
            ser_self: JSONdict = self.serialize()
            ser_other: JSONdict = other.serialize()

        for field in dataclasses.fields(self):  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
            # fields excluded from comparison are excluded from the diff too
            if not field.compare:
                continue

            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # nested serializable dataclasses: recurse, keep only non-empty diffs
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
                continue

            # plain dataclasses cannot be diffed
            if dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")

            # choose what gets compared: serialized data or the raw values
            if of_serialized:
                self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
            else:
                self_value_s = self_value
                other_value_s = other_value

            if not array_safe_eq(self_value_s, other_value_s):
                diff_result[field_name] = {"self": self_value, "other": other_value}

        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields whose current value is a `SerializableDataclass` are updated recursively,
        all other fields present in `nested_dict` are overwritten via `setattr`.
        Fields missing from `nested_dict` are left untouched.

        # Parameters:
         - `nested_dict : dict[str, Any]`
           nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name not in nested_dict:
                continue

            new_value = nested_dict[field_name]
            if isinstance(self_value, SerializableDataclass):
                self_value.update_from_nested_dict(new_value)
            else:
                setattr(self, field_name, new_value)

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        json_roundtrip = json.loads(json.dumps(self.serialize()))
        return self.__class__.load(json_roundtrip)

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json (`memo` is not used)"
        json_roundtrip = json.loads(json.dumps(self.serialize()))
        return self.__class__.load(json_roundtrip)
# resolving type hints repeatedly is wasteful, so memoize the result per class
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    """cached `typing.get_type_hints` for a class

    NOTE: repeated calls with the same class return the same dict object
    """
    resolved_hints: dict[str, Any] = typing.get_type_hints(cls)
    return resolved_hints
def get_cls_type_hints(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries the cached lookup first and falls back to an uncached `typing.get_type_hints`
    if the cached result is empty

    # Raises:
     - `TypeError` : if the hints cannot be resolved (`TypeError`, `NameError`) or are empty;
       the original exception is chained as the cause
    """
    cls_type_hints: dict[str, Any]
    try:
        # an empty (falsy) cached result triggers one uncached retry
        cls_type_hints = get_cls_type_hints_cached(cls) or typing.get_type_hints(cls)  # type: ignore

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            f" {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            f" {e = }"
        ) from e

    return cls_type_hints
class KWOnlyError(NotImplementedError):
    """raised when `kw_only=True` is requested on a python version where `serializable_dataclass` cannot support it"""
class FieldError(ValueError):
    """base class for field errors"""


class NotSerializableFieldException(FieldError):
    """field is not a `SerializableField`"""


class FieldSerializationError(FieldError):
    """error while serializing a field"""


class FieldLoadingError(FieldError):
    """error while loading a field"""


class FieldTypeMismatchError(FieldError, TypeError):
    """error when a field type does not match the type hint

    also a `TypeError`, so it can be caught as either
    """
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs: Any,
) -> Any:
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       whether to add an `__init__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `repr : bool`
       whether to add a `__repr__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `eq : bool`
       *(passed to dataclasses.dataclass)*
       note that `__eq__` is replaced afterwards unless listed in `methods_no_override`
       (defaults to `True`)
     - `order : bool`
       whether to add rich comparison methods
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `unsafe_hash : bool`
       whether to add a `__hash__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `frozen : bool`
       whether to make the class frozen
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       which properties to add to the serialized data dict
       **SerializableDataclass only**
       (defaults to `None`)
     - `register_handler : bool`
       if true, register the class with ZANJ for loading
       **SerializableDataclass only**
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
       **SerializableDataclass only**
     - `on_typecheck_mismatch : ErrorMode`
       what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
       **SerializableDataclass only**
     - `methods_no_override : list[str]|None`
       list of methods that should not be overridden by the decorator
       by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
       but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence.
       the misspelled legacy key `validate_field_type` is still accepted as an alias for `validate_fields_types`
       (it triggers the "unknown methods" warning)
       **SerializableDataclass only**
       (defaults to `None`)
     - `**kwargs`
       *(passed to dataclasses.dataclass)*

    # Returns:

     - `_type_`
       the decorated class (or, when used with arguments, a decorator that returns it)

    # Raises:

     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]:
        # Replace the class attribute of every annotated field with a SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self: Any) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    value: Any = None  # init before try in case getattr raises
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):  # pyright: ignore[reportAttributeAccessIssue]
                            value = value.serialize()  # pyright: ignore[reportAttributeAccessIssue]
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value or '<unavailable>' = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(
            cls: type[T_SerializeableDataclass],
            data: dict[str, Any] | T_SerializeableDataclass,
        ) -> T_SerializeableDataclass:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(data, typing.Mapping), (
                f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
            )

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            # mypy doesn't recognize @dataclass_transform for dataclasses.fields()
            # https://github.com/python/mypy/issues/16241
            for field in dataclasses.fields(cls):  # type: ignore[arg-type]
                # check if the field is a SerializableField
                assert isinstance(field, SerializableField), (
                    f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
                )

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        # NOTE: `loading_fn` receives the whole `data` mapping, not just this field's value
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: T_SerializeableDataclass = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined, method-assign]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined, method-assign, assignment]

        # the method installed here is `validate_fields_types`, so that is the key that must
        # opt out of the override. the misspelled `validate_field_type` used to be the only key
        # checked, so it is still honored to stay backward-compatible
        if _methods_no_override.isdisjoint(
            {"validate_fields_types", "validate_field_type"}
        ):
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined, method-assign]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    # support both `@serializable_dataclass` and `@serializable_dataclass(...)`
    if _cls is None:
        return wrap
    else:
        return wrap(_cls)