Coverage for muutils/json_serialize/serializable_dataclass.py: 56%
256 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""save and load objects to and from json or compatible formats in a recoverable way
3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
4you will get an error when you call `json.dumps(d)`. This module provides a way around that.
6Instead, you define your class:
8```python
9@serializable_dataclass
10class MyClass(SerializableDataclass):
11 a: int
12 b: str
13```
15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
17 >>> my_obj = MyClass(a=1, b="q")
18 >>> s = json.dumps(my_obj.serialize())
19 >>> s
20 '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
21 >>> read_obj = MyClass.load(json.loads(s))
22 >>> read_obj == my_obj
23 True
25This isn't too impressive on its own, but it gets more useful when you have nested classes,
26or fields that are not json-serializable by default:
28```python
29@serializable_dataclass
30class NestedClass(SerializableDataclass):
31 x: str
32 y: MyClass
33 act_fun: torch.nn.Module = serializable_field(
34 default=torch.nn.ReLU(),
35 serialization_fn=lambda x: str(x),
36 deserialize_fn=lambda x: getattr(torch.nn, x)(),
37 )
38```
40which gives us:
42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
43 >>> s = json.dumps(nc.serialize())
44 >>> s
45 '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
46 >>> read_nc = NestedClass.load(json.loads(s))
47 >>> read_nc == nc
48 True
50"""
52from __future__ import annotations
54import abc
55import dataclasses
56import functools
57import json
58import sys
59import typing
60import warnings
61from typing import Any, Optional, Type, TypeVar
63from muutils.errormode import ErrorMode
64from muutils.validate_type import validate_type
65from muutils.json_serialize.serializable_field import (
66 SerializableField,
67 serializable_field,
68)
69from muutils.json_serialize.util import _FORMAT_KEY, array_safe_eq, dc_eq
71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
73# this is quite horrible, but unfortunately mypy fails if we try to assign to `dataclass_transform` directly
74# and every time we try to init a serializable dataclass it says the argument doesnt exist
75try:
76 try:
77 # type ignore here for legacy versions
78 from typing import dataclass_transform # type: ignore[attr-defined]
79 except Exception:
80 from typing_extensions import dataclass_transform
81except Exception:
82 from muutils.json_serialize.dataclass_transform_mock import dataclass_transform
# generic type variable used in the annotations below (e.g. `Type[T]` for "a class" and `T` for "an instance of it")
T = TypeVar("T")
class CantGetTypeHintsWarning(UserWarning):
    """warning category for cases where type hints cannot be retrieved"""
class ZanjMissingWarning(UserWarning):
    """warning category for a missing [`ZANJ`](https://github.com/mivanit/ZANJ) install -- without it, `register_loader_serializable_dataclass` will not work"""
# module-level flag read by `zanj_register_loader_serializable_dataclass` to decide whether to attempt the ZANJ import
# NOTE(review): nothing in the visible part of this module ever sets it to `False` -- confirm whether that is intended
_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Hook a serializable dataclass into the ZANJ loader registry

    once registered, `ZANJ().read()` can load the class and not just return plain dicts.
    if `zanj` cannot be imported, nothing is registered and the function returns `None`.

    # Parameters:
     - `cls : typing.Type[T]`
       class to register, its `load` classmethod is what the handler calls

    # Returns:
     - the `LoaderHandler` that was registered, or `None` if `zanj` is not importable

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: ZANJ is optional -- if it is not installed, failing to register the loader
            # handler doesn't matter, so we skip silently (no `ZanjMissingWarning` is emitted)
            return None

    # serialized items carry this tag under `_FORMAT_KEY`
    format_tag: str = f"{cls.__name__}(SerializableDataclass)"

    handler: LoaderHandler = LoaderHandler(
        # an item matches if it is a dict whose format entry starts with our tag
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(format_tag)
        ),
        # loading is delegated to the class itself
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=format_tag,
        source_pckg=cls.__module__,
        desc=f"{format_tag} loader via muutils.json_serialize.serializable_dataclass",
    )
    register_loader_handler(handler)

    return handler
# defaults used by the validation helpers and the `serializable_dataclass` decorator below:
# a value that does not match its type hint is handled in `WARN` mode,
# an exception raised by the type check itself is handled in `EXCEPT` mode
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
class FieldIsNotInitOrSerializeWarning(UserWarning):
    """warning category used when type validation skips a field that is not `init` or not `serialize`"""
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value stored in one field of a dataclass instance agrees with its type hint

    this function is exposed as `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       instance whose field is checked
     - `field : SerializableField | str`
       the field to check; a `str` is looked up in `self.__dataclass_fields__`
     - `on_typecheck_error : ErrorMode`
       how to react when the check itself raises (except, warn, ignore). When the error is not re-raised, the function returns `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       `True` if the field is valid or was skipped (`assert_type` off, or field not `init`/`serialize`),
       `False` if the type does not match or an error was handled without raising
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve a field name to the field object stored on the dataclass
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # type assertion is switched off for this field: nothing to do
    if not _field.assert_type:
        return True

    # fields left out of `init` or `serialize` are skipped, but we warn about it
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the hint for this field among the class-level type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                f"{get_cls_type_hints(self.__class__) = }\n"
                f"Python version is {sys.version_info = }. You can:\n"
                f" - disable `assert_type`. Currently: {_field.assert_type = }\n"
                f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                " - use python 3.9.x or higher\n"
                " - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # current value of the field on this instance
    value: Any = getattr(self, _field.name)

    try:
        custom_check = _field.custom_typecheck_fn
        if custom_check is None:
            # default validator: compare the value against the hint
            return validate_type(value, field_type_hint)
        # NOTE(review): the custom checker is called with the type hint, not with the value -- confirm this is intended
        return custom_check(field_type_hint)
    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """run `validate_field_type` on every field of a `SerializableDataclass` instance

    exceptions raised for individual fields are collected (the field is marked invalid) and
    reported together through `on_typecheck_error` once all fields have been visited

    # Returns:
     - `dict[str, bool]`
       maps each field name to whether its type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    all_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    validity: dict[str, bool] = {}
    # per-field exceptions, bundled so that they can be reported in one go
    failures: dict[str, Exception] = {}

    for fld in all_fields:
        try:
            validity[fld.name] = self.validate_field_type(fld, on_typecheck_error)
        except Exception as err:
            validity[fld.name] = False
            failures[fld.name] = err

    if failures:
        details: str = "\n\t".join(f"{k}:\t{v}" for k, v in failures.items())
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in all_fields]}"
            + "\n\t"
            + details,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception only
            except_from=next(iter(failures.values())),
        )

    return validity
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check every field of a `SerializableDataclass` instance, returns `True` only if all of them are valid

    thin wrapper around `SerializableDataclass__validate_fields_types__dict`
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Abstract base that serializable dataclasses inherit from

    the base class is here for linting and type checking, the working `serialize` and `load`
    are only installed once the class is decorated with `serializable_dataclass`

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    `my_obj.serialize()` then gives a dict that can be serialized to json:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This gets more useful with nested classes, or with fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "placeholder for the dict serialization, the real implementation is installed by the `@serializable_dataclass` decorator"
        message: str = f"decorate {self.__class__ = } with `@serializable_dataclass`"
        raise NotImplementedError(message)

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "placeholder for building an instance from an appropriately structured dict, the real implementation is installed by the `@serializable_dataclass` decorator"
        message: str = f"decorate {cls = } with `@serializable_dataclass`"
        raise NotImplementedError(message)

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """check all fields against their type hints, delegates to `SerializableDataclass__validate_fields_types`"""
        all_valid: bool = SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
        return all_valid

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """check a single field against its type hint, delegates to `SerializableDataclass__validate_field_type`"""
        is_valid: bool = SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
        return is_valid

    def __eq__(self, other: Any) -> bool:
        # equality is delegated to the `dc_eq` helper
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hash of the json dump of the serialized instance"
        as_json: str = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """recursive, field-by-field diff between two instances of the same serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           instance to compare against
         - `of_serialized : bool`
           if true, fields are compared through their serialized form instead of the raw values
           (the reported `"self"`/`"other"` entries are the raw values either way)
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if nothing differs; otherwise maps field names to `{"self": ..., "other": ...}`,
           or to a nested diff when both values are serializable dataclasses

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # both sides have to be exactly the same class
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        differences: dict = {}

        # equal instances have an empty diff; if `==` itself fails, fall through to the field-wise comparison
        try:
            if self == other:
                return differences
        except Exception:
            pass

        # serialized forms are only needed (and only computed) when requested
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # fields excluded from comparison are ignored
            if not field.compare:
                continue

            name: str = field.name
            mine = getattr(self, name)
            theirs = getattr(other, name)

            # two serializable dataclasses: recurse, keep only non-empty nested diffs
            if isinstance(mine, SerializableDataclass) and isinstance(
                theirs, SerializableDataclass
            ):
                sub_diff: dict = mine.diff(theirs, of_serialized=of_serialized)
                if sub_diff:
                    differences[name] = sub_diff
                continue

            # plain dataclasses are not supported
            if dataclasses.is_dataclass(mine) and dataclasses.is_dataclass(theirs):
                raise ValueError("Non-serializable dataclass is not supported")

            # pick what gets compared: serialized entries or raw values
            if of_serialized:
                lhs, rhs = ser_self[name], ser_other[name]
            else:
                lhs, rhs = mine, theirs

            if not array_safe_eq(lhs, rhs):
                differences[name] = {"self": mine, "other": theirs}

        return differences

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update fields in place from a nested dict, useful for configuration from command line args

        values for fields that currently hold a `SerializableDataclass` are passed down recursively,
        everything else is assigned directly. fields missing from the dict are left alone.

        # Parameters:
         - `nested_dict : dict[str, Any]`
           maps field names to new values (or to nested dicts for nested dataclasses)
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            name: str = field.name
            current = getattr(self, name)

            if name not in nested_dict:
                continue

            new_value = nested_dict[name]
            if isinstance(current, SerializableDataclass):
                current.update_from_nested_dict(new_value)
            else:
                setattr(self, name, new_value)

    def __copy__(self) -> "SerializableDataclass":
        "copy (deep) by round-tripping the instance through its json serialization"
        as_json: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(as_json))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by round-tripping the instance through its json serialization, `memo` is not used"
        as_json: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(as_json))
# memoized so that the hints of a class only have to be resolved once
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "`typing.get_type_hints` for a class, behind an `lru_cache`"
    hints: dict[str, Any] = typing.get_type_hints(cls)
    return hints
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """get the type hints of a class, raising a descriptive `TypeError` if that is not possible

    the cached lookup is tried first; if it comes back empty, `typing.get_type_hints` is called directly.
    hints that are still empty count as a failure.

    # Raises:
     - `TypeError` : if resolving the hints raises `TypeError`/`NameError`, or if no hints are found
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        # an empty cached result gets one uncached retry
        if not cls_type_hints:
            cls_type_hints = typing.get_type_hints(cls)

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # every failure mode is re-raised as a `TypeError` carrying some debugging context
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            f" {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            f" {e = }"
        ) from e

    return cls_type_hints
class KWOnlyError(NotImplementedError):
    """raised when `kw_only=True` is requested on python < 3.10, where kw-only dataclasses are not supported"""
class FieldError(ValueError):
    """common base class for the field-related errors of this module"""
class NotSerializableFieldException(FieldError):
    """raised when a dataclass field is not a `SerializableField`"""
class FieldSerializationError(FieldError):
    """raised when serializing a field fails"""
class FieldLoadingError(FieldError):
    """raised when loading a field fails"""
class FieldTypeMismatchError(FieldError, TypeError):
    """raised when the value of a field does not match its type hint"""
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       whether to add an `__init__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `repr : bool`
       whether to add a `__repr__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `eq : bool`
       *(passed to dataclasses.dataclass)*; note that unless `"__eq__"` is in `methods_no_override`,
       `__eq__` is replaced with a `dc_eq`-based comparison afterwards
       (defaults to `True`)
     - `order : bool`
       whether to add rich comparison methods
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `unsafe_hash : bool`
       whether to add a `__hash__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `frozen : bool`
       whether to make the class frozen
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       which properties to add to the serialized data dict
       **SerializableDataclass only**
       (defaults to `None`)
     - `register_handler : bool`
       if true, register the class with ZANJ for loading
       **SerializableDataclass only**
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
       **SerializableDataclass only**
     - `on_typecheck_mismatch : ErrorMode`
       what to do if a type mismatch is found (except, warn, ignore). If `ignore`, field types are not validated at all when loading
       **SerializableDataclass only**
     - `methods_no_override : list[str]|None`
       list of methods that should not be overridden by the decorator
       by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
       but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
       **SerializableDataclass only**
       (defaults to `None`)
     - `**kwargs`
       *(passed to dataclasses.dataclass)*

    # Returns:

     - `_type_`
       the decorated class

    # Raises:

     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Replace regular class-level field values with `SerializableField` for every annotated name
        for field_name in cls.__annotations__:
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    # bound before the `try` so the error message below can always be built,
                    # even when `getattr` itself is what failed
                    value: Any = None
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f" but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that (it gets the field's value)
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that (it gets the whole data mapping)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, report them according to `on_typecheck_mismatch`
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        # the documented key is `validate_fields_types` (matching the attribute assigned here);
        # the misspelled `validate_field_type` used to be the only key honored, so it is still accepted
        if not _methods_no_override & {"validate_fields_types", "validate_field_type"}:
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    # support both `@serializable_dataclass` and `@serializable_dataclass(...)`
    if _cls is None:
        return wrap
    else:
        return wrap(_cls)