Coverage for muutils / json_serialize / serializable_dataclass.py: 56%

261 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1"""save and load objects to and from json or compatible formats in a recoverable way 

2 

3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, 

4you will get an error when you call `json.dumps(d)`. This module provides a way around that. 

5 

6Instead, you define your class: 

7 

8```python 

9@serializable_dataclass 

10class MyClass(SerializableDataclass): 

11 a: int 

12 b: str 

13``` 

14 

15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do: 

16 

17 >>> my_obj = MyClass(a=1, b="q") 

18 >>> s = json.dumps(my_obj.serialize()) 

19 >>> s 

20 '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' 

21 >>> read_obj = MyClass.load(json.loads(s)) 

22 >>> read_obj == my_obj 

23 True 

24 

25This isn't too impressive on its own, but it gets more useful when you have nested classes, 

26or fields that are not json-serializable by default: 

27 

28```python 

29@serializable_dataclass 

30class NestedClass(SerializableDataclass): 

31 x: str 

32 y: MyClass 

33 act_fun: torch.nn.Module = serializable_field( 

34 default=torch.nn.ReLU(), 

35 serialization_fn=lambda x: type(x).__name__, 

36 deserialize_fn=lambda x: getattr(torch.nn, x)(), 

37 ) 

38``` 

39 

40which gives us: 

41 

42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) 

43 >>> s = json.dumps(nc.serialize()) 

44 >>> s 

45 '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' 

46 >>> read_nc = NestedClass.load(json.loads(s)) 

47 >>> read_nc == nc 

48 True 

49 

50""" 

51 

52from __future__ import annotations 

53 

54import abc 

55import dataclasses 

56import functools 

57import json 

58import sys 

59import typing 

60import warnings 

61from typing import Any, Optional, Type, TypeVar, overload, TYPE_CHECKING 

62 

63from muutils.errormode import ErrorMode 

64from muutils.validate_type import validate_type 

65from muutils.json_serialize.serializable_field import ( 

66 SerializableField, 

67 serializable_field, 

68) 

69from muutils.json_serialize.types import _FORMAT_KEY 

70from muutils.json_serialize.util import ( 

71 JSONdict, 

72 array_safe_eq, 

73 dc_eq, 

74) 

75 

76# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access 

77 

78# For type checkers: always use typing_extensions which they can resolve 

79# At runtime: use stdlib if available (3.11+), else typing_extensions, else mock 

80if TYPE_CHECKING: 

81 from typing_extensions import dataclass_transform, Self 

82else: 

83 if sys.version_info >= (3, 11): 

84 from typing import dataclass_transform, Self 

85 else: 

86 try: 

87 from typing_extensions import dataclass_transform, Self 

88 except Exception: 

89 from muutils.json_serialize.dataclass_transform_mock import ( 

90 dataclass_transform, 

91 ) 

92 

93 Self = TypeVar("Self") 

94 

# type variable standing for "some subclass of `SerializableDataclass`", so that
# helpers and classmethods can be typed in terms of the concrete subclass
T_SerializeableDataclass = TypeVar(
    "T_SerializeableDataclass",
    bound="SerializableDataclass",
)


class CantGetTypeHintsWarning(UserWarning):
    """special warning for when we can't get type hints"""


class ZanjMissingWarning(UserWarning):
    """special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"""


# flag to keep track of if we have successfully imported ZANJ
# NOTE(review): nothing visible in this module ever sets this to `False`, so the
# ZANJ import in `zanj_register_loader_serializable_dataclass` is attempted on every call
_zanj_loading_needs_import: bool = True

114 

115 

def zanj_register_loader_serializable_dataclass(
    cls: typing.Type[T_SerializeableDataclass],
):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T_SerializeableDataclass]`
        the decorated class to register. its loader claims any dict whose `_FORMAT_KEY`
        value is a string starting with `"<cls.__name__>(SerializableDataclass)"`

    # Returns:
     - the `LoaderHandler` that was registered, or `None` if ZANJ could not be imported
       (in which case nothing is registered)

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import] # pyright: ignore[reportMissingImports]
                LoaderHandler,  # pyright: ignore[reportUnknownVariableType]
                register_loader_handler,  # pyright: ignore[reportUnknownVariableType]
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler
            # doesnt matter, so we return quietly rather than emitting a `ZanjMissingWarning`
            return

    # the same prefix `serialize()` writes under `_FORMAT_KEY` for this class
    _format: str = f"{cls.__name__}(SerializableDataclass)"

    def _check(json_item: Any, path: Any = None, z: Any = None) -> bool:
        # only claim dicts tagged with this class's format string. the `isinstance(..., str)`
        # guard keeps a dict whose format value is not a string from raising
        # `AttributeError` on `.startswith` while ZANJ searches its handlers
        if not isinstance(json_item, dict):
            return False
        format_value: Any = json_item.get(_FORMAT_KEY)
        return isinstance(format_value, str) and format_value.startswith(_format)

    def _load(json_item: Any, path: Any = None, z: Any = None) -> Any:
        # delegate to the class's own `load`; `path` and `z` are accepted for the ZANJ interface
        return cls.load(json_item)

    lh: LoaderHandler = LoaderHandler(  # pyright: ignore[reportPossiblyUnboundVariable]
        check=_check,
        load=_load,
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)  # pyright: ignore[reportPossiblyUnboundVariable]

    return lh

158 

159 

# default handling for a field whose value does not match its type hint
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
# default handling for an exception raised while *performing* the type check
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT


class FieldIsNotInitOrSerializeWarning(UserWarning):
    """emitted when a field is skipped during type validation because it is not `init` or not `serialize`"""

166 

167 

def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
        `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
        (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`

    # Raises:
     - `KeyError` : if `field` is a `str` that is not a field of `self`
     - `AssertionError` : if the field is not a `SerializableField`
     - `TypeError` : via `on_typecheck_error`, if the type hint for the field cannot be obtained
     - `ValueError` : via `on_typecheck_error`, if the type validator itself raises
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # check this *before* reading `SerializableField`-only attributes such as `assert_type`;
    # previously a plain `dataclasses.Field` failed with an `AttributeError` before reaching it
    assert isinstance(_field, SerializableField), (
        f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"
    )

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    # get field type hints. `get_cls_type_hints` raises `TypeError` when the class's hints
    # cannot be resolved at all, and indexing raises `KeyError` when this field has no hint.
    # both are routed through `on_typecheck_error` so that `warn`/`ignore` return `False`
    # as documented, instead of the `TypeError` escaping uncaught. the hints are fetched
    # once so building the error message cannot raise a second time
    cls_type_hints: Optional[dict[str, Any]] = None
    try:
        cls_type_hints = get_cls_type_hints(self.__class__)
        field_type_hint: Any = cls_type_hints[_field.name]
    except (KeyError, TypeError) as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{cls_type_hints = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        if _field.custom_typecheck_fn is None:
            # validate the type with the default type validator
            type_is_valid = validate_type(value, field_type_hint)
        else:
            # validate the type with a custom type validator
            # NOTE(review): the custom validator receives the type *hint*, not `value` --
            # confirm this matches the signature documented on `SerializableField.custom_typecheck_fn`
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

258 

259 

def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    every field is checked even if an earlier one raises; any exceptions raised along the way
    are collected and reported together through `on_typecheck_error` after the loop.

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    error_mode: ErrorMode = ErrorMode.from_any(on_typecheck_error)

    all_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    validity: dict[str, bool] = {}
    failures: dict[str, Exception] = {}
    for fld in all_fields:
        try:
            is_valid: bool = self.validate_field_type(fld, error_mode)
        except Exception as exc:
            # record the failure and keep going so that every field gets checked
            is_valid = False
            failures[fld.name] = exc
        validity[fld.name] = is_valid

    # nothing raised, so there is nothing to report
    if not failures:
        return validity

    header: str = f"Exceptions while validating types of fields on {self.__class__.__name__}: {[item.name for item in all_fields]}"
    details: str = "\n\t".join([f"{name}:\t{err}" for name, err in failures.items()])
    error_mode.process(
        header + "\n\t" + details,
        except_cls=ValueError,
        # HACK: ExceptionGroup is unavailable before python 3.11, so chain from the first recorded exception
        except_from=next(iter(failures.values())),
    )

    return validity

295 

296 

def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns `True` only when every field's type is valid
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())

307 

308 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: type(x).__name__,
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @overload
    @classmethod
    def load(cls, data: dict[str, Any]) -> Self: ...

    @overload
    @classmethod
    def load(cls, data: Self) -> Self: ...

    @classmethod
    def load(cls, data: dict[str, Any] | Self) -> Self:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
            other instance to compare against
         - `of_serialized : bool`
            if true, compare serialized data and not raw values. the reported
            `self`/`other` entries are still the raw (unserialized) values
            (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
            maps each differing field name either to `{"self": ..., "other": ...}` or, for
            nested serializable dataclasses, to their own nested diff. empty if no differences

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        # (best effort: if equality itself raises, fall through to the field-by-field comparison)
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: JSONdict = self.serialize()
            ser_other: JSONdict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                if of_serialized:
                    self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                    other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                else:
                    self_value_s = self_value
                    other_value_s = other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        a field holding a `SerializableDataclass` is updated recursively when the matching
        value in `nested_dict` is a mapping; any other value (including a whole replacement
        instance) is assigned directly. keys of `nested_dict` that are not fields are ignored.

        # Parameters:
         - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            if field_name not in nested_dict:
                continue

            self_value = getattr(self, field_name)
            new_value = nested_dict[field_name]

            # only recurse when there is actually a nested mapping to recurse into. previously a
            # non-mapping value here was passed to the recursive call and failed on `in`
            if isinstance(self_value, SerializableDataclass) and isinstance(
                new_value, typing.Mapping
            ):
                self_value.update_from_nested_dict(new_value)  # type: ignore[arg-type]
            else:
                setattr(self, field_name, new_value)

    def __copy__(self) -> "SerializableDataclass":
        """copy by serializing and loading the instance to json

        NOTE: this intentionally produces a *deep* copy even for `copy.copy`, since the
        json round-trip never shares sub-objects with the original
        """
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        """deep copy by serializing and loading the instance to json

        `memo` is unused: the json round-trip creates fresh objects for everything
        """
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

518 

519 

# memoize `typing.get_type_hints` per class object, so that `load` and field validation
# do not re-resolve the same hints every time. classes are hashable, so they work as keys
@functools.lru_cache(maxsize=128, typed=True)
def get_cls_type_hints_cached(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    """cached `typing.get_type_hints` for a class

    repeated calls for the same class return the very same dict object,
    so callers should treat the result as read-only
    """
    resolved: dict[str, Any] = typing.get_type_hints(cls)
    return resolved

526 

527 

def get_cls_type_hints(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    """helper function to get type hints for a class

    uses the cached lookup first, falling back to an uncached `typing.get_type_hints`
    call if the cached result is empty.

    # Parameters:
     - `cls : Type[T_SerializeableDataclass]`
        class to get the type hints of

    # Returns:
     - `dict[str, Any]`
        mapping of attribute name to resolved type hint. this may be the dict held by the
        cache, so do not mutate it

    # Raises:
     - `TypeError` : if the hints cannot be resolved (a `TypeError`/`NameError` from
       `typing.get_type_hints`), or if the class has no type hints at all
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # build the field listing defensively: `dataclasses.fields` raises its own `TypeError`
        # for non-dataclasses, which previously replaced this more informative error
        try:
            fields_repr: str = str(dataclasses.fields(cls))  # type: ignore[arg-type]
        except TypeError:
            fields_repr = "<not a dataclass>"
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f" dataclasses.fields(cls) = {fields_repr}\n"
            + f" {e = }"
        ) from e

    return cls_type_hints

547 

548 

class KWOnlyError(NotImplementedError):
    """kw-only dataclasses are not supported in python <3.10

    `dataclasses.dataclass(kw_only=...)` was added in python 3.10; `serializable_dataclass`
    raises this when `kw_only=True` is requested on an older interpreter
    """


class FieldError(ValueError):
    """base class for field errors"""


class NotSerializableFieldException(FieldError):
    """field is not a `SerializableField`"""


class FieldSerializationError(FieldError):
    """error while serializing a field"""


class FieldLoadingError(FieldError):
    """error while loading a field"""


class FieldTypeMismatchError(FieldError, TypeError):
    """error when a field type does not match the type hint"""

583 

584 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs: Any,
) -> Any:
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields. Every annotated attribute is turned
    into a `SerializableField`; a plain class-level default (`x: int = 5`) is kept as that
    field's default, exactly as `dataclasses.dataclass` would treat it.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

     - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
     - `init : bool`
        whether to add an `__init__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
     - `repr : bool`
        whether to add a `__repr__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
     - `order : bool`
        whether to add rich comparison methods
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
     - `unsafe_hash : bool`
        whether to add a `__hash__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
     - `frozen : bool`
        whether to make the class frozen
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
        which properties to add to the serialized data dict
        **SerializableDataclass only**
        (defaults to `None`)
     - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
     - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
     - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        (the legacy spelling `validate_field_type` is also honored for `validate_fields_types`)
        **SerializableDataclass only**
        (defaults to `None`)
     - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

     - `_type_`
        the decorated class

    # Raises:

     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]:
        # make sure every annotated attribute is backed by a `SerializableField`
        # before `dataclasses.dataclass` processes the class
        for field_name in cls.__annotations__:
            # `MISSING` (not `None`) as the sentinel, so an explicit `= None` default is kept
            field_value: Any = getattr(cls, field_name, dataclasses.MISSING)
            if isinstance(field_value, SerializableField):
                continue
            if isinstance(field_value, dataclasses.Field):
                # Convert the field to a SerializableField while preserving properties
                field_value = SerializableField.from_Field(field_value)
            elif field_value is dataclasses.MISSING:
                # no default given: create a plain SerializableField
                field_value = serializable_field()
            else:
                # a plain default like `x: int = 5`: keep it as the field's default.
                # previously this value was discarded, silently making the field required.
                # (as with `dataclasses.dataclass`, mutable defaults are rejected there)
                field_value = serializable_field(default=field_value)
            setattr(cls, field_name, field_value)

        # special check: `kw_only` only exists in `dataclasses.dataclass` on python >= 3.10.
        # a `False` value is dropped so older versions still work; `True` cannot be emulated.
        # compared with `== True` rather than truthiness since `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self: Any) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    value: Any = None  # init before try in case getattr raises
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):  # pyright: ignore[reportAttributeAccessIssue]
                            value = value.serialize()  # pyright: ignore[reportAttributeAccessIssue]
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        # NOTE(review): for a `SerializableDataclass` value the first branch has already
                        # produced a dict, so a `serialization_fn` (if set) is applied to that dict
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value or '<unavailable>' = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        # leading space: the two pieces used to run together ("...MyClassbut it is in")
                        + f" but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(
            cls: type[T_SerializeableDataclass],
            data: dict[str, Any] | T_SerializeableDataclass,
        ) -> T_SerializeableDataclass:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(data, typing.Mapping), (
                f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
            )

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            # mypy doesn't recognize @dataclass_transform for dataclasses.fields()
            # https://github.com/python/mypy/issues/16241
            for field in dataclasses.fields(cls):  # type: ignore[arg-type]
                # check if the field is a SerializableField
                assert isinstance(field, SerializableField), (
                    f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
                )

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that (it receives the whole `data` mapping)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: T_SerializeableDataclass = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined, method-assign]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined, method-assign, assignment]

        # `validate_fields_types` is the documented (and overridden) name; previously only the
        # misspelled `validate_field_type` was checked here, so the documented opt-out did nothing.
        # the legacy spelling is still honored for backward compatibility
        if not ({"validate_fields_types", "validate_field_type"} & _methods_no_override):
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined, method-assign]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)