Coverage for muutils/json_serialize/serializable_dataclass.py: 56%

256 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""save and load objects to and from json or compatible formats in a recoverable way 

2 

3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, 

4you will get an error when you call `json.dumps(d)`. This module provides a way around that. 

5 

6Instead, you define your class: 

7 

8```python 

9@serializable_dataclass 

10class MyClass(SerializableDataclass): 

11 a: int 

12 b: str 

13``` 

14 

15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do: 

16 

17 >>> my_obj = MyClass(a=1, b="q") 

18 >>> s = json.dumps(my_obj.serialize()) 

19 >>> s 

20 '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' 

21 >>> read_obj = MyClass.load(json.loads(s)) 

22 >>> read_obj == my_obj 

23 True 

24 

25This isn't too impressive on its own, but it gets more useful when you have nested classes,

26or fields that are not json-serializable by default: 

27 

28```python 

29@serializable_dataclass 

30class NestedClass(SerializableDataclass): 

31 x: str 

32 y: MyClass 

33 act_fun: torch.nn.Module = serializable_field( 

34 default=torch.nn.ReLU(), 

35 serialization_fn=lambda x: str(x), 

36 deserialize_fn=lambda x: getattr(torch.nn, x)(), 

37 ) 

38``` 

39 

40which gives us: 

41 

42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) 

43 >>> s = json.dumps(nc.serialize()) 

44 >>> s 

45 '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' 

46 >>> read_nc = NestedClass.load(json.loads(s)) 

47 >>> read_nc == nc 

48 True 

49 

50""" 

51 

52from __future__ import annotations 

53 

54import abc 

55import dataclasses 

56import functools 

57import json 

58import sys 

59import typing 

60import warnings 

61from typing import Any, Optional, Type, TypeVar 

62 

63from muutils.errormode import ErrorMode 

64from muutils.validate_type import validate_type 

65from muutils.json_serialize.serializable_field import ( 

66 SerializableField, 

67 serializable_field, 

68) 

69from muutils.json_serialize.util import _FORMAT_KEY, array_safe_eq, dc_eq 

70 

71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access 

72 

73# this is quite horrible, but unfortunately mypy fails if we try to assign to `dataclass_transform` directly 

74# and every time we try to init a serializable dataclass it says the argument doesnt exist 

75try: 

76 try: 

77 # type ignore here for legacy versions 

78 from typing import dataclass_transform # type: ignore[attr-defined] 

79 except Exception: 

80 from typing_extensions import dataclass_transform 

81except Exception: 

82 from muutils.json_serialize.dataclass_transform_mock import dataclass_transform 

83 

# generic type variable, used in this module to annotate helpers that take a class (`Type[T]`) or return an instance of it (`T`)
T = TypeVar("T")

85 

86 

class CantGetTypeHintsWarning(UserWarning):
    """special warning for when we can't get type hints"""

91 

92 

class ZanjMissingWarning(UserWarning):
    """special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"""

97 

98 

# module-level flag read by `zanj_register_loader_serializable_dataclass` to decide whether to attempt the ZANJ import.
# NOTE(review): nothing visible in this file ever sets it to `False` -- confirm whether that is intended
_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"

101 

102 

def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ loader registry

    once registered, `ZANJ().read()` can rebuild instances of `cls` instead of
    handing back plain dicts

    # Parameters:
     - `cls : typing.Type[T]`
       class to register; the handler loads matching items through `cls.load`

    # Returns:
     - the `LoaderHandler` that was registered, or `None` if `zanj` cannot be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: ZANJ is an optional dependency -- without it there is nothing to register with,
            # so failing to register does not matter and we return quietly (no warning is emitted)
            return

    # tag written under `_FORMAT_KEY` by `serialize`, also used as the handler uid
    format_tag: str = f"{cls.__name__}(SerializableDataclass)"

    def _matches_format(json_item, path=None, z=None):  # type: ignore
        "true for dicts whose `_FORMAT_KEY` entry starts with this class's format tag"
        return (
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(format_tag)
        )

    def _load_item(json_item, path=None, z=None):  # type: ignore
        "rebuild an instance of `cls`; `path` and `z` are accepted but unused"
        return cls.load(json_item)  # type: ignore

    handler: LoaderHandler = LoaderHandler(
        check=_matches_format,
        load=_load_item,
        uid=format_tag,
        source_pckg=cls.__module__,
        desc=f"{format_tag} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(handler)

    return handler

143 

144 

# default `ErrorMode`s, used as parameter defaults by the type-validation helpers and the `serializable_dataclass` decorator below
# - mismatch: what to do when a value does not match its type hint
# - error: what to do when the type check itself throws an exception
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT

147 

148 

class FieldIsNotInitOrSerializeWarning(UserWarning):
    "warning issued when a field is not `init` or not `serialize`, and is therefore skipped by type checking"

151 

152 

def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value of one field on a dataclass instance agrees with its type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       instance whose field is being checked
     - `field : SerializableField | str`
       the field to check; an `str` is looked up in `self.__dataclass_fields__`
     - `on_typecheck_error : ErrorMode`
       how to react when the check itself throws (except, warn, ignore). if the mode
       does not raise, the function returns `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       `True` if the type is valid, or if the field is skipped (`assert_type` is falsy,
       or the field is not `init`/`serialize`). `False` if the type is invalid, or the
       check failed and `on_typecheck_error` did not raise
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve a field name to the field object stored on the dataclass
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # type assertion is switched off for this field: nothing to do
    if not _field.assert_type:
        return True

    # fields which are not both `init` and `serialize` are skipped, with a warning
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the resolved type hint for this field
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        # the hints were obtained, but this field is not among them
        on_typecheck_error.process(
            "".join(
                [
                    f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n",
                    f"{get_cls_type_hints(self.__class__) = }\n",
                    f"Python version is {sys.version_info = }. You can:\n",
                    f" - disable `assert_type`. Currently: {_field.assert_type = }\n",
                    f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n",
                    " - use python 3.9.x or higher\n",
                    " - specify custom type validation function via `custom_typecheck_fn`\n",
                ]
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # current value stored on the instance
    value: Any = getattr(self, _field.name)

    try:
        if _field.custom_typecheck_fn is not None:
            # custom validator: note that it is called with the type hint only, not the value
            return _field.custom_typecheck_fn(field_type_hint)
        # default validator
        return validate_type(value, field_type_hint)
    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

243 

244 

def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the type of every field on a `SerializableDataclass`, via `self.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       instance to validate
     - `on_typecheck_error : ErrorMode`
       passed on to each per-field check, and also used to report the bundle of
       exceptions raised by those checks
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `dict[str, bool]`
       field name -> whether the field type is valid. a field whose check raised is `False`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    # per-field outcome, plus whatever exceptions the checks raised, so they can be reported together
    validity: dict[str, bool] = {}
    failures: dict[str, Exception] = {}

    for fld in cls_fields:
        try:
            validity[fld.name] = self.validate_field_type(fld, on_typecheck_error)
        except Exception as exc:
            validity[fld.name] = False
            failures[fld.name] = exc

    # report all the collected exceptions in one go
    if failures:
        failure_lines: str = "\n\t".join(
            f"{k}:\t{v}" for k, v in failures.items()
        )
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + failure_lines,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception only
            except_from=next(iter(failures.values())),
        )

    return validity

280 

281 

def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`

    thin wrapper around `SerializableDataclass__validate_fields_types__dict`:
    returns `True` only if every field was reported as valid
    """
    per_field_valid: dict[str, bool] = (
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        )
    )
    return all(per_field_valid.values())

292 

293 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        # placeholder only: the real implementation is attached by the decorator
        message: str = f"decorate {self.__class__ = } with `@serializable_dataclass`"
        raise NotImplementedError(message)

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        # placeholder only: the real implementation is attached by the decorator
        message: str = f"decorate {cls = } with `@serializable_dataclass`"
        raise NotImplementedError(message)

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields, delegating to `SerializableDataclass__validate_fields_types`"""
        all_valid: bool = SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
        return all_valid

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """check that one field matches its type hint, delegating to `SerializableDataclass__validate_field_type`"""
        field_valid: bool = SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
        return field_valid

    def __eq__(self, other: Any) -> bool:
        "equality as decided by `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        as_json: str = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values. the values *reported* in the
           diff are the raw attribute values either way
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           maps each differing field to `{"self": ..., "other": ...}`, or to a nested diff
           when both values are `SerializableDataclass` instances. empty if nothing differs

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # both sides must be exactly the same type
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        diff_result: dict = {}

        # shortcut: equal instances have an empty diff. if the comparison itself
        # blows up, fall through to the field-by-field comparison
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # serialized views are only needed (and only built) when comparing serialized data
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # fields excluded from comparison are excluded from the diff too
            if not field.compare:
                continue

            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # two serializable dataclasses: recurse, and only record a non-empty diff
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
                continue

            # plain dataclasses on both sides are not supported
            if dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")

            # pick what to compare: serialized data or the actual values
            if of_serialized:
                self_value_s = ser_self[field_name]
                other_value_s = ser_other[field_name]
            else:
                self_value_s = self_value
                other_value_s = other_value

            if not array_safe_eq(self_value_s, other_value_s):
                diff_result[field_name] = {"self": self_value, "other": other_value}

        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields holding a `SerializableDataclass` are updated recursively from the matching
        sub-dict, any other field present in `nested_dict` is overwritten with `setattr`

        # Parameters:
         - `nested_dict : dict[str, Any]`
           nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            # fields not mentioned in the dict are left alone
            if field_name not in nested_dict:
                continue

            if isinstance(self_value, SerializableDataclass):
                self_value.update_from_nested_dict(nested_dict[field_name])
            else:
                setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        as_json: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(as_json))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        # `memo` is not used: the copy is rebuilt entirely from the json round trip
        as_json: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(as_json))

491 

492 

493# cache this so we don't have to keep getting it 

494# TODO: are the types hashable? does this even make sense? 

@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    """cached `typing.get_type_hints` for a class

    results are memoized by `functools.lru_cache`, so repeated calls with the same
    class return the same dict object
    """
    resolved_hints: dict[str, Any] = typing.get_type_hints(cls)
    return resolved_hints

499 

500 

def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries the cached lookup first, and retries with an uncached `typing.get_type_hints`
    if the cached result is empty

    # Parameters:
     - `cls : Type[T]`
       class to get the type hints of

    # Returns:
     - `dict[str, Any]`
       attribute name -> resolved type hint, never empty

    # Raises:
     - `TypeError` : if the hints cannot be resolved (`TypeError`, `NameError` or `ValueError`
       from the lookup), or if the class has no type hints at all
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if not cls_type_hints:
            # empty result from the cache, ask `typing` directly
            cls_type_hints = typing.get_type_hints(cls)

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # `dataclasses.fields` raises its own `TypeError` when `cls` is not a dataclass.
        # building the message must not replace the error we are actually reporting
        fields_repr: str
        try:
            fields_repr = repr(dataclasses.fields(cls))  # type: ignore[arg-type]
        except TypeError as fields_err:
            fields_repr = f"<unavailable: {fields_err}>"

        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f" dataclasses.fields(cls) = {fields_repr}\n"
            + f" {e = }"
        ) from e

    return cls_type_hints

520 

521 

class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

526 

527 

class FieldError(ValueError):
    """base class for field errors"""

532 

533 

class NotSerializableFieldException(FieldError):
    """field is not a `SerializableField`"""

538 

539 

class FieldSerializationError(FieldError):
    """error while serializing a field"""

544 

545 

class FieldLoadingError(FieldError):
    """error while loading a field"""

550 

551 

class FieldTypeMismatchError(FieldError, TypeError):
    """error when a field type does not match the type hint"""

556 

557 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       whether to add an `__init__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `repr : bool`
       whether to add a `__repr__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `eq : bool`
       *(passed to dataclasses.dataclass)*
       note that `__eq__` is then replaced with a `dc_eq`-based one unless listed in `methods_no_override`
       (defaults to `True`)
     - `order : bool`
       whether to add rich comparison methods
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `unsafe_hash : bool`
       whether to add a `__hash__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `frozen : bool`
       whether to make the class frozen
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       which properties to add to the serialized data dict
       **SerializableDataclass only**
       (defaults to `None`)
     - `register_handler : bool`
       if true, register the class with ZANJ for loading
       **SerializableDataclass only**
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
       **SerializableDataclass only**
     - `on_typecheck_mismatch : ErrorMode`
       what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
       **SerializableDataclass only**
     - `methods_no_override : list[str]|None`
       list of methods that should not be overridden by the decorator
       by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
       but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
       **SerializableDataclass only**
       (defaults to `None`)
     - `**kwargs`
       *(passed to dataclasses.dataclass)*

    # Returns:

     - `_type_`
       the decorated class

    # Raises:

     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Replace regular class-level field values with `SerializableField`s, for every annotated name
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy,
        # hence the explicit comparison against `True`
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    # bind `value` before the `try`: if fetching the attribute fails, the error
                    # message below must still be buildable (and must not show a stale value
                    # left over from the previous field)
                    value: Any = None
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        # NOTE: this is a separate `if`, so the `elif` below is also reached
                        # when `value` was just serialized by the branch above
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that (it receives the field's value)
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that (it receives the whole `data` mapping)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        # NOTE: this guard used to test for "validate_field_type" (singular), which is neither in the
        # known-methods set above nor the attribute assigned here, so asking to keep a custom
        # `validate_fields_types` was silently ignored. the old spelling is still honored for compatibility
        if not (
            {"validate_fields_types", "validate_field_type"} & _methods_no_override
        ):
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    # used as `@serializable_dataclass(...)`: hand back the decorator. used bare: decorate right away
    if _cls is None:
        return wrap
    else:
        return wrap(_cls)