Coverage for tests / unit / json_serialize / serializable_dataclass / test_serializable_dataclass.py: 89%

533 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-18 21:32 -0600

1from __future__ import annotations 

2 

3from copy import deepcopy 

4import typing 

5from typing import Any, Dict, Generic, List, Optional, TypeVar, Union 

6 

7import pytest 

8 

9from muutils.errormode import ErrorMode 

10from muutils.json_serialize import ( 

11 SerializableDataclass, 

12 serializable_dataclass, 

13 serializable_field, 

14) 

15 

16from muutils.json_serialize.serializable_dataclass import ( 

17 FieldIsNotInitOrSerializeWarning, 

18 FieldTypeMismatchError, 

19) 

20from muutils.json_serialize.types import _FORMAT_KEY 

21 

22# pylint: disable=missing-class-docstring, unused-variable 

23 

24 

@serializable_dataclass
class BasicAutofields(SerializableDataclass):
    """Only plain annotated fields (no `serializable_field` options), so every
    field goes through the default serialization and loading path."""

    a: str
    b: int
    c: typing.List[int]

30 

31 

def test_basic_auto_fields():
    """Serialize a class made only of plain annotated fields and check equality,
    both reflexively and against an independently built equal instance."""
    data = dict(a="hello", b=42, c=[1, 2, 3])
    instance = BasicAutofields(**data)  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
    data_with_format = data.copy()
    data_with_format[_FORMAT_KEY] = "BasicAutofields(SerializableDataclass)"
    assert instance.serialize() == data_with_format

    # reflexive equality and empty self-diff
    assert instance == instance
    assert instance.diff(instance) == {}

    # comparing an object with itself is trivially true for any reflexive
    # __eq__, so also compare against a distinct object holding equal values,
    # and against the result of a serialize/load round trip
    other = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    assert other is not instance
    assert instance == other
    assert instance.diff(other) == {}
    assert BasicAutofields.load(instance.serialize()) == instance

40 

41 

def test_basic_diff():
    """`diff` lists exactly the fields that differ, each as a self/other pair."""
    base = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    other_a = BasicAutofields(a="goodbye", b=42, c=[1, 2, 3])
    other_b = BasicAutofields(a="hello", b=-1, c=[1, 2, 3])
    other_bc = BasicAutofields(a="hello", b=-1, c=[42])

    expectations = [
        (base, other_a, {"a": {"self": "hello", "other": "goodbye"}}),
        (base, other_b, {"b": {"self": 42, "other": -1}}),
        (
            base,
            other_bc,
            {
                "b": {"self": 42, "other": -1},
                "c": {"self": [1, 2, 3], "other": [42]},
            },
        ),
        # no differences against itself
        (base, base, {}),
        # two fields differ between the two modified instances
        (
            other_a,
            other_b,
            {
                "a": {"self": "goodbye", "other": "hello"},
                "b": {"self": 42, "other": -1},
            },
        ),
    ]
    for lhs, rhs, expected in expectations:
        assert lhs.diff(rhs) == expected

59 

60 

@serializable_dataclass
class SimpleFields(SerializableDataclass):
    """A required field `d`, a plain default `e`, and a `default_factory` list `f`."""

    d: str
    e: int = 42
    # default_factory gives each instance its own list instead of a shared default
    f: typing.List[int] = serializable_field(default_factory=list)  # noqa: F821

66 

67 

@serializable_dataclass
class FieldOptions(SerializableDataclass):
    """Exercises per-field options accepted by `serializable_field`."""

    a: str = serializable_field()
    b: str = serializable_field()
    # excluded from __init__, serialization, repr and comparison; no default is
    # given, so presumably the attribute stays unset after __init__ (TODO confirm).
    # Loading this class emits `FieldIsNotInitOrSerializeWarning`
    # (asserted in test_field_options_loading).
    c: str = serializable_field(init=False, serialize=False, repr=False, compare=False)
    # serialized upper-cased; `loading_fn` is handed the whole serialized dict
    # and lower-cases its "d" entry on load
    d: str = serializable_field(
        serialization_fn=lambda x: x.upper(), loading_fn=lambda x: x["d"].lower()
    )

76 

77 

@serializable_dataclass(properties_to_serialize=["full_name"])
class WithProperty(SerializableDataclass):
    """Two string fields plus a derived `full_name` property.

    `properties_to_serialize` adds the property's value to the serialized dict.
    The property is not a field, so it is not an __init__ argument.
    """

    first_name: str
    last_name: str

    @property
    def full_name(self) -> str:
        # "<first_name> <last_name>"
        return f"{self.first_name} {self.last_name}"

86 

87 

class Child(FieldOptions, WithProperty):
    """Combines two serializable dataclasses through multiple inheritance.

    Not itself decorated with `@serializable_dataclass`.
    NOTE(review): no visible test uses this class; confirm it is still needed.
    """

90 

91 

@pytest.fixture
def simple_fields_instance():
    """A `SimpleFields` with every field given explicitly."""
    instance = SimpleFields(d="hello", e=42, f=[1, 2, 3])
    return instance

95 

96 

@pytest.fixture
def field_options_instance():
    """A `FieldOptions` built from only its init-able fields (`c` is init=False)."""
    instance = FieldOptions(a="hello", b="world", d="case")
    return instance

100 

101 

@pytest.fixture
def with_property_instance():
    """A `WithProperty` whose `full_name` property evaluates to "John Doe"."""
    instance = WithProperty(first_name="John", last_name="Doe")
    return instance

105 

106 

def test_simple_fields_serialization(simple_fields_instance):
    """The serialized dict contains every field plus the format key."""
    expected = {
        "d": "hello",
        "e": 42,
        "f": [1, 2, 3],
        _FORMAT_KEY: "SimpleFields(SerializableDataclass)",
    }
    assert simple_fields_instance.serialize() == expected

115 

116 

def test_simple_fields_loading(simple_fields_instance):
    """Loading the serialized dict yields an equal instance, with an empty diff in both directions."""
    round_tripped = SimpleFields.load(simple_fields_instance.serialize())

    assert round_tripped == simple_fields_instance
    for left, right in (
        (round_tripped, simple_fields_instance),
        (simple_fields_instance, round_tripped),
    ):
        assert left.diff(right) == {}

125 

126 

def test_field_options_serialization(field_options_instance):
    """`c` (serialize=False) is absent; `d` passes through its upper-casing serialization_fn."""
    expected = {
        "a": "hello",
        "b": "world",
        "d": "CASE",
        _FORMAT_KEY: "FieldOptions(SerializableDataclass)",
    }
    assert field_options_instance.serialize() == expected

135 

136 

def test_field_options_loading(field_options_instance):
    """Loading `FieldOptions` round-trips the instance and emits a
    `FieldIsNotInitOrSerializeWarning`."""
    serialized = field_options_instance.serialize()
    # the warning is *required*, not ignored: `pytest.warns` fails the test if
    # it is not emitted. Presumably it is triggered by field `c`, which is
    # declared with init=False and serialize=False.
    with pytest.warns(FieldIsNotInitOrSerializeWarning):
        loaded = FieldOptions.load(serialized)
    assert loaded == field_options_instance
    # `d` was upper-cased on serialize; its loading_fn lower-cases it again
    assert loaded.d == "case"

143 

144 

def test_with_property_serialization(with_property_instance):
    """Properties listed in `properties_to_serialize` are serialized next to the fields."""
    expected = {
        "first_name": "John",
        "last_name": "Doe",
        "full_name": "John Doe",
        _FORMAT_KEY: "WithProperty(SerializableDataclass)",
    }
    assert with_property_instance.serialize() == expected

153 

154 

def test_with_property_loading(with_property_instance):
    """The extra serialized property key does not prevent loading an equal instance."""
    restored = WithProperty.load(with_property_instance.serialize())
    assert restored == with_property_instance

159 

160 

@serializable_dataclass
class Address(SerializableDataclass):
    """Leaf dataclass nested inside `Person`."""

    street: str
    city: str
    zip_code: str

166 

167 

@serializable_dataclass
class Person(SerializableDataclass):
    """Dataclass with a nested `SerializableDataclass` field (`address`)."""

    name: str
    age: int
    # serialized as a nested dict that carries its own format key
    address: Address

173 

174 

@pytest.fixture
def address_instance():
    """A fully populated `Address`."""
    address = Address(street="123 Main St", city="New York", zip_code="10001")
    return address

178 

179 

@pytest.fixture
def person_instance(address_instance):
    """A `Person` that wraps the `address_instance` fixture."""
    person = Person(name="John Doe", age=30, address=address_instance)
    return person

183 

184 

def test_nested_serialization(person_instance):
    """A nested dataclass serializes to a nested dict; each level has its own format key."""
    address_part = {
        "street": "123 Main St",
        "city": "New York",
        "zip_code": "10001",
        _FORMAT_KEY: "Address(SerializableDataclass)",
    }
    expected_ser = {
        "name": "John Doe",
        "age": 30,
        "address": address_part,
        _FORMAT_KEY: "Person(SerializableDataclass)",
    }
    assert person_instance.serialize() == expected_ser

199 

200 

def test_nested_loading(person_instance):
    """Loading restores both the outer dataclass and the nested one."""
    restored = Person.load(person_instance.serialize())
    assert restored == person_instance
    # also compare the nested level directly
    assert restored.address == person_instance.address

206 

207 

def test_with_printing():
    """Round-trip a class that combines a custom (de)serialization pair, a
    default_factory field and a serialized property, printing both forms."""

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class MyClass(SerializableDataclass):
        name: str
        # serialized as age + 1; loading_fn reverses the shift
        age: int = serializable_field(
            serialization_fn=lambda x: x + 1, loading_fn=lambda x: x["age"] - 1
        )
        items: List[Any] = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    # Usage
    my_instance = MyClass(name="John", age=30, items=["apple", "banana"])
    serialized_data = my_instance.serialize()
    print(serialized_data)
    # printing alone verifies nothing, so pin the serialized values too
    assert serialized_data["name"] == "John"
    assert serialized_data["age"] == 31
    assert serialized_data["items"] == ["apple", "banana"]
    assert serialized_data["full_name"] == "John Doe"

    loaded_instance = MyClass.load(serialized_data)
    print(loaded_instance)
    assert loaded_instance == my_instance
    assert loaded_instance.age == 30

228 

229 

def test_simple_class_serialization():
    """A class defined inside a test serializes and loads like a module-level one."""

    @serializable_dataclass
    class SimpleClass(SerializableDataclass):
        a: int
        b: str

    original = SimpleClass(a=42, b="hello")
    as_dict = original.serialize()
    assert as_dict == {
        "a": 42,
        "b": "hello",
        _FORMAT_KEY: "SimpleClass(SerializableDataclass)",
    }
    assert SimpleClass.load(as_dict) == original

246 

247 

def test_error_when_init_and_not_serialize():
    """Declaring a field with init=True but serialize=False must raise ValueError
    at class-creation time (presumably because `load` could never supply it)."""
    with pytest.raises(ValueError):
        # the decorator itself is expected to raise, so the class is never bound
        @serializable_dataclass
        class SimpleClass(SerializableDataclass):
            a: int = serializable_field(init=True, serialize=False)

254 

255 

def test_person_serialization():
    """Defaults, a default_factory list and a serialized property in one class."""

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class FullPerson(SerializableDataclass):
        name: str = serializable_field()
        age: int = serializable_field(default=-1)
        items: typing.List[str] = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    # `age` is omitted, so its default of -1 is used
    person = FullPerson(name="John", items=["apple", "banana"])
    serialized = person.serialize()
    expected_ser = {
        "name": "John",
        "age": -1,
        "items": ["apple", "banana"],
        "full_name": "John Doe",
        _FORMAT_KEY: "FullPerson(SerializableDataclass)",
    }
    assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}"
    assert FullPerson.load(serialized) == person

281 

282 

def test_custom_serialization():
    """serialization_fn doubles the value on the way out; loading_fn halves it back."""

    @serializable_dataclass
    class CustomSerialization(SerializableDataclass):
        data: Any = serializable_field(
            serialization_fn=lambda x: x * 2, loading_fn=lambda x: x["data"] // 2
        )

    original = CustomSerialization(data=5)
    dumped = original.serialize()
    expected = {
        "data": 10,
        _FORMAT_KEY: "CustomSerialization(SerializableDataclass)",
    }
    assert dumped == expected
    assert CustomSerialization.load(dumped) == original

299 

300 

@serializable_dataclass
class Nested_with_Container(SerializableDataclass):
    """Holds a list of `BasicAutofields` with explicit per-element (de)serialization."""

    val_int: int
    val_str: str
    # serialization_fn serializes each element; loading_fn receives the whole
    # serialized dict and rebuilds each element from its "val_list" entry
    val_list: typing.List[BasicAutofields] = serializable_field(
        default_factory=list,
        serialization_fn=lambda x: [y.serialize() for y in x],
        loading_fn=lambda x: [BasicAutofields.load(y) for y in x["val_list"]],
    )

310 

311 

def test_nested_with_container():
    """A list of nested dataclasses round-trips through custom (de)serialization fns."""
    members = [
        BasicAutofields(a="a", b=1, c=[1, 2, 3]),
        BasicAutofields(a="b", b=2, c=[4, 5, 6]),
    ]
    instance = Nested_with_Container(val_int=42, val_str="hello", val_list=members)

    dumped = instance.serialize()
    expected = {
        "val_int": 42,
        "val_str": "hello",
        "val_list": [
            {
                "a": "a",
                "b": 1,
                "c": [1, 2, 3],
                _FORMAT_KEY: "BasicAutofields(SerializableDataclass)",
            },
            {
                "a": "b",
                "b": 2,
                "c": [4, 5, 6],
                _FORMAT_KEY: "BasicAutofields(SerializableDataclass)",
            },
        ],
        _FORMAT_KEY: "Nested_with_Container(SerializableDataclass)",
    }
    assert dumped == expected

    assert Nested_with_Container.load(dumped) == instance

348 

349 

class Custom_class_with_serialization:
    """Custom class that does not inherit from `SerializableDataclass` but still
    provides `serialize` / `load`, so it can be used as a dataclass field type."""

    def __init__(self, a: int, b: str):
        self.a: int = a
        self.b: str = b

    def serialize(self):
        """Return a plain dict with keys "a" and "b"."""
        return {"a": self.a, "b": self.b}

    @classmethod
    def load(cls, data):
        """Rebuild an instance from a dict shaped like the output of `serialize`."""
        return cls(data["a"], data["b"])

    def __eq__(self, other):
        # For foreign types, return NotImplemented instead of raising
        # AttributeError on `other.a`, so Python falls back to its default
        # handling (identity), which makes the comparison evaluate to False.
        if not isinstance(other, Custom_class_with_serialization):
            return NotImplemented
        return (self.a == other.a) and (self.b == other.b)

366 

367 

@serializable_dataclass
class nested_custom(SerializableDataclass):
    """Dataclass whose `data1` field is a plain class with its own serialize/load."""

    value: float
    # not a SerializableDataclass; its serialized form matches the field
    # value's own `serialize()` output (see test_nested_custom)
    data1: Custom_class_with_serialization

372 

373 

def test_nested_custom(recwarn):
    """A field typed as a plain class exposing serialize/load round-trips.

    `recwarn` records (and so keeps out of the test output) the warnings that
    are presumably emitted for this non-`SerializableDataclass` field type;
    their content is not asserted.
    """
    member = Custom_class_with_serialization(1, "hello")
    instance = nested_custom(value=42.0, data1=member)
    dumped = instance.serialize()
    assert dumped == {
        "value": 42.0,
        "data1": {"a": 1, "b": "hello"},
        _FORMAT_KEY: "nested_custom(SerializableDataclass)",
    }
    assert nested_custom.load(dumped) == instance

387 

388 

def test_deserialize_fn():
    """`deserialize_fn` is handed the field's serialized value itself (here the
    string "5"), unlike `loading_fn`, which is handed the whole dict."""

    @serializable_dataclass
    class DeserializeFn(SerializableDataclass):
        data: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x),
        )

    original = DeserializeFn(data=5)
    dumped = original.serialize()
    assert dumped == {
        "data": "5",
        _FORMAT_KEY: "DeserializeFn(SerializableDataclass)",
    }

    restored = DeserializeFn.load(dumped)
    assert restored == original
    assert restored.data == 5

407 

408 

@serializable_dataclass
class DictContainer(SerializableDataclass):
    """Test class with dictionary fields: one required flat dict, plus a nested
    dict and a flat dict that both default to empty."""

    simple_dict: Dict[str, int]
    nested_dict: Dict[str, Dict[str, int]] = serializable_field(default_factory=dict)
    # despite the name this is not Optional[...]; it only defaults to {}
    optional_dict: Dict[str, str] = serializable_field(default_factory=dict)

416 

417 

def test_dict_serialization():
    """Dict fields serialize to equal plain dicts alongside the format key."""
    container = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3, "z": 4}},
        optional_dict={"hello": "world"},
    )

    expected = {
        _FORMAT_KEY: "DictContainer(SerializableDataclass)",
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }
    assert container.serialize() == expected

435 

436 

def test_dict_loading():
    """A hand-written serialized dict loads into the matching field values."""
    payload = {
        _FORMAT_KEY: "DictContainer(SerializableDataclass)",
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }

    restored = DictContainer.load(payload)
    # expected values are fresh literals, not the payload's own dict objects
    for attr, expected in (
        ("simple_dict", {"a": 1, "b": 2}),
        ("nested_dict", {"x": {"y": 3, "z": 4}}),
        ("optional_dict", {"hello": "world"}),
    ):
        assert getattr(restored, attr) == expected

450 

451 

def test_dict_equality():
    """Instances compare equal exactly when their dict contents are equal."""

    def make(b_value):
        # identical apart from simple_dict["b"]
        return DictContainer(
            simple_dict={"a": 1, "b": b_value},
            nested_dict={"x": {"y": 3, "z": 4}},
            optional_dict={"hello": "world"},
        )

    first = make(2)
    second = make(2)
    different = make(3)

    assert first == second
    assert first != different
    assert second != different

475 

476 

def test_dict_diff():
    """`diff` on dict fields reports the whole old and new dict for each changed field."""

    def build(**overrides):
        # fresh baseline dicts on every call, with some fields optionally replaced
        fields = dict(
            simple_dict={"a": 1, "b": 2},
            nested_dict={"x": {"y": 3, "z": 4}},
            optional_dict={"hello": "world"},
        )
        fields.update(overrides)
        return DictContainer(**fields)

    baseline = build()
    cases = [
        (
            build(simple_dict={"a": 1, "b": 3}),
            {"simple_dict": {"self": {"a": 1, "b": 2}, "other": {"a": 1, "b": 3}}},
        ),
        (
            build(nested_dict={"x": {"y": 3, "z": 5}}),
            {
                "nested_dict": {
                    "self": {"x": {"y": 3, "z": 4}},
                    "other": {"x": {"y": 3, "z": 5}},
                }
            },
        ),
        (
            build(optional_dict={"hello": "python"}),
            {
                "optional_dict": {
                    "self": {"hello": "world"},
                    "other": {"hello": "python"},
                }
            },
        ),
        # the same instance -> no differences
        (baseline, {}),
    ]
    for other, expected in cases:
        assert baseline.diff(other) == expected

529 

530 

@serializable_dataclass
class ComplexDictContainer(SerializableDataclass):
    """Test class with more complex dictionary structures: heterogeneous
    values, list values, and three levels of nesting."""

    mixed_dict: Dict[str, Any]
    list_dict: Dict[str, typing.List[int]]
    multi_nested: Dict[str, Dict[str, Dict[str, int]]]

538 

539 

def test_complex_dict_serialization():
    """Complex dict structures survive a serialize/load round trip unchanged."""
    original = ComplexDictContainer(
        mixed_dict={"str": "hello", "int": 42, "list": [1, 2, 3]},
        list_dict={"a": [1, 2, 3], "b": [4, 5, 6]},
        multi_nested={"x": {"y": {"z": 1, "w": 2}, "v": {"u": 3, "t": 4}}},
    )

    restored = ComplexDictContainer.load(original.serialize())
    assert restored == original
    assert restored.diff(original) == {}

552 

553 

def test_empty_dicts():
    """Empty dicts round-trip and compare equal to other empty instances."""

    def empty():
        return DictContainer(simple_dict={}, nested_dict={}, optional_dict={})

    data = empty()
    restored = DictContainer.load(data.serialize())
    assert restored == data
    assert restored.diff(data) == {}

    # a second, independently built empty instance is equal as well
    assert data == empty()

566 

567 

# Fixture class for dictionary value type validation (used by the skipped
# test_dict_type_validation)
@serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
class StrictDictContainer(SerializableDataclass):
    """Test class with strict dictionary typing: a type mismatch is configured
    with ErrorMode.EXCEPT rather than a softer error mode."""

    int_dict: Dict[str, int]
    str_dict: Dict[str, str]
    float_dict: Dict[str, float]

576 

577 

# TODO: find out why mismatched dict *values* do not raise at construction time
@pytest.mark.skip(reason="dict type validation doesnt seem to work")
def test_dict_type_validation():
    """Test type validation for dictionary values"""
    well_typed = StrictDictContainer(
        int_dict={"a": 1, "b": 2},
        str_dict={"x": "hello", "y": "world"},
        float_dict={"m": 1.0, "n": 2.5},
    )
    assert well_typed.validate_fields_types()

    # each case puts a wrongly typed value in exactly one dict
    bad_kwargs = [
        dict(int_dict={"a": "not an int"}, str_dict={"x": "hello"}, float_dict={"m": 1.0}),
        dict(int_dict={"a": 1}, str_dict={"x": 123}, float_dict={"m": 1.0}),
    ]
    for kwargs in bad_kwargs:
        with pytest.raises(FieldTypeMismatchError):
            StrictDictContainer(**kwargs)  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

605 

606 

# Test dictionary with optional values
@serializable_dataclass
class OptionalDictContainer(SerializableDataclass):
    """Test class with Optional/Union dictionary values, plus a dict field that
    may itself be None."""

    optional_values: Dict[str, Optional[int]]
    union_values: Dict[str, Union[int, str]]
    # the whole field is optional and defaults to None
    nullable_dict: Optional[Dict[str, int]] = None

615 

616 

def test_optional_dict_values():
    """Test dictionaries with optional/union values"""
    cases = [
        dict(
            optional_values={"a": 1, "b": None, "c": 3},
            union_values={"x": 1, "y": "string", "z": 42},
            nullable_dict={"m": 1, "n": 2},
        ),
        # every optional value None, every union value a str, the dict field None
        dict(
            optional_values={"a": None, "b": None},
            union_values={"x": "all strings", "y": "here"},
            nullable_dict=None,
        ),
    ]
    for kwargs in cases:
        instance = OptionalDictContainer(**kwargs)
        assert OptionalDictContainer.load(instance.serialize()) == instance

639 

640 

# Test dictionary mutation
def test_dict_mutation():
    """Mutating one instance's dict fields does not leak into a deep copy, and diff reports it."""
    original = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3}},
        optional_dict={"hello": "world"},
    )
    snapshot = deepcopy(original)

    # mutate every dict field of `original` in place, including a nested level
    original.simple_dict["c"] = 3
    original.nested_dict["x"]["z"] = 4
    original.optional_dict["new"] = "value"

    # the deep copy still holds the pre-mutation contents
    assert snapshot.simple_dict == {"a": 1, "b": 2}
    assert snapshot.nested_dict == {"x": {"y": 3}}
    assert snapshot.optional_dict == {"hello": "world"}

    changes = snapshot.diff(original)
    for field_name in ("simple_dict", "nested_dict", "optional_dict"):
        assert field_name in changes

667 

668 

# Test dictionary key types
@serializable_dataclass
class IntKeyDictContainer(SerializableDataclass):
    """Test class with non-string dictionary keys"""

    # keys are stringified on serialize and parsed back with int() on load;
    # loading_fn receives the whole serialized dict, hence x["int_keys"]
    int_keys: Dict[int, str] = serializable_field(
        serialization_fn=lambda x: {str(k): v for k, v in x.items()},
        loading_fn=lambda x: {int(k): v for k, v in x["int_keys"].items()},
    )

678 

679 

def test_non_string_dict_keys():
    """Integer dict keys become strings when serialized and integers again when loaded."""
    instance = IntKeyDictContainer(int_keys={1: "one", 2: "two", 3: "three"})

    serialized = instance.serialize()
    assert all(isinstance(key, str) for key in serialized["int_keys"])

    restored = IntKeyDictContainer.load(serialized)
    assert all(isinstance(key, int) for key in restored.int_keys)
    assert restored == instance

692 

693 

@serializable_dataclass
class RecursiveDictContainer(SerializableDataclass):
    """Test class holding arbitrarily nested data. The annotation is a plain
    `Dict[str, Any]`, not a recursive type alias."""

    data: Dict[str, Any]

699 

700 

def test_recursive_dict_structure():
    """Deeply nested dicts, including a dict inside a list, round-trip unchanged."""
    innermost = {"value": 42, "list": [1, 2, {"nested": "value"}]}
    deep_dict = {"level1": {"level2": {"level3": innermost}}}

    instance = RecursiveDictContainer(data=deep_dict)
    restored = RecursiveDictContainer.load(instance.serialize())

    assert restored == instance
    assert restored.data == deep_dict

715 

716 

# Defined at module level rather than inside the test that uses it. Under
# `from __future__ import annotations` field annotations are strings, and the
# field-type validator can presumably only resolve names found in the module
# namespace. TODO confirm.
class CustomSerializable:
    """Plain (non-dataclass) value wrapper that provides `serialize` / `load`."""

    def __init__(self, value):
        self.value: Union[str, int] = value

    def serialize(self):
        """Return ``{"value": <wrapped value>}``."""
        payload = {"value": self.value}
        return payload

    @classmethod
    def load(cls, data):
        """Inverse of `serialize`."""
        wrapped = data["value"]
        return cls(wrapped)

    def __eq__(self, other):
        # equal only to another CustomSerializable holding an equal value
        if not isinstance(other, CustomSerializable):
            return False
        return self.value == other.value

731 

732 

def test_dict_with_custom_objects():
    """Test dictionaries containing custom objects that implement serialize/load"""

    @serializable_dataclass
    class CustomObjectDict(SerializableDataclass):
        data: Dict[str, CustomSerializable] = serializable_field()

    instance: CustomObjectDict = CustomObjectDict(
        data={"a": CustomSerializable(42), "b": CustomSerializable("hello")}
    )

    # construction keeps the wrapper objects and their value types intact
    assert isinstance(instance, CustomObjectDict)
    assert isinstance(instance.data, dict)
    for key, value_type in (("a", int), ("b", str)):
        assert isinstance(instance.data[key], CustomSerializable)
        assert isinstance(instance.data[key].value, value_type)

    restored = CustomObjectDict.load(instance.serialize())
    assert restored == instance

754 

755 

def test_empty_optional_dicts():
    """None and {} in an Optional dict field remain distinguishable after a round trip."""

    @serializable_dataclass
    class OptionalDictFields(SerializableDataclass):
        required_dict: Dict[str, int]
        optional_dict: Optional[Dict[str, int]] = None
        default_empty: Dict[str, int] = serializable_field(default_factory=dict)

    with_none = OptionalDictFields(required_dict={"a": 1}, optional_dict=None)
    with_empty = OptionalDictFields(required_dict={"a": 1}, optional_dict={})

    dumped_none = with_none.serialize()
    dumped_empty = with_empty.serialize()

    restored_none = OptionalDictFields.load(dumped_none)
    restored_empty = OptionalDictFields.load(dumped_empty)

    assert restored_none.optional_dict is None
    assert restored_empty.optional_dict == {}
    # the default_factory field comes back as an empty dict in both cases
    assert restored_none.default_empty == {}
    assert restored_empty.default_empty == {}

781 

782 

# Test inheritance hierarchies
@serializable_dataclass(
    on_typecheck_error=ErrorMode.EXCEPT, on_typecheck_mismatch=ErrorMode.EXCEPT
)
class BaseClass(SerializableDataclass):
    """Base class for testing inheritance; both type-check hooks are configured
    with ErrorMode.EXCEPT."""

    base_field: str
    # default 0 here; ChildClass redeclares it with default 1
    shared_field: int = serializable_field(default=0)

@serializable_dataclass
class ChildClass(BaseClass):
    """Child class inheriting from BaseClass"""

    child_field: float = serializable_field(default=0.1)
    # redeclares the base class field with a different default (1 instead of 0)
    shared_field: int = serializable_field(default=1)  # Override base class field

800 

801 

@serializable_dataclass
class GrandchildClass(ChildClass):
    """Grandchild class for deep (three-level) inheritance testing"""

    grandchild_field: bool = serializable_field(default=True)

807 

808 

def test_inheritance():
    """Fields from every level of the hierarchy are serialized and restored."""
    instance = GrandchildClass(
        base_field="base", shared_field=42, child_field=3.14, grandchild_field=True
    )

    serialized = instance.serialize()
    for key, expected in (
        ("base_field", "base"),
        ("shared_field", 42),
        ("child_field", 3.14),
    ):
        assert serialized[key] == expected
    assert serialized["grandchild_field"] is True

    assert GrandchildClass.load(serialized) == instance

    # loading through the base class yields a BaseClass, not a subclass
    as_base = BaseClass.load({"base_field": "test", "shared_field": 1})
    assert isinstance(as_base, BaseClass)
    assert not isinstance(as_base, ChildClass)

828 

829 

@pytest.mark.skip(
    reason="Not implemented yet, generic types not supported and throw a `TypeHintNotImplementedError`"
)
def test_generic_types():
    """Test handling of generic type parameters"""

    T = TypeVar("T")

    @serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
    class GenericContainer(SerializableDataclass, Generic[T]):
        """Test generic type parameters"""

        value: T
        values: List[T]

    # the same round trip, once specialized to int and once to str
    for param, scalar, sequence in (
        (int, 42, [1, 2, 3]),
        (str, "hello", ["a", "b", "c"]),
    ):
        container = GenericContainer[param](value=scalar, values=sequence)
        restored = GenericContainer[param].load(container.serialize())
        assert restored == container

856 

857 

def test_literal_field_typecheck():
    """Literal fields should validate correctly without suppressing type errors."""

    @serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
    class LiteralFieldClass(SerializableDataclass):
        mode: typing.Literal["fast", "slow", "auto"] = serializable_field(
            default="fast"
        )
        count: typing.Literal[1, 2, 3] = serializable_field(default=1)

    # permitted values: validation passes
    assert LiteralFieldClass(mode="fast", count=2).validate_fields_types() is True

    # bypass __init__ so an out-of-range value can be stored directly
    unchecked = LiteralFieldClass.__new__(LiteralFieldClass)
    for attr, value in (("mode", "invalid"), ("count", 1)):
        object.__setattr__(unchecked, attr, value)
    assert unchecked.validate_fields_types() is False

    # the same invalid value through load() raises (on_typecheck_mismatch=EXCEPT)
    with pytest.raises(Exception):
        LiteralFieldClass.load({"mode": "invalid", "count": 1})

881 

882 

def test_literal_roundtrip():
    """Round-trip serialize/load for dataclasses with Literal fields of various kinds."""

    @serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
    class LiteralRoundtrip(SerializableDataclass):
        mode: typing.Literal["fast", "slow"] = serializable_field(default="fast")
        count: typing.Literal[1, 2, 3] = serializable_field(default=1)
        tag: typing.Optional[typing.Literal["a", "b"]] = serializable_field(
            default=None
        )
        items: typing.List[typing.Literal["x", "y", "z"]] = serializable_field(
            default_factory=list
        )
        scores: typing.Dict[str, typing.Literal[0, 1]] = serializable_field(
            default_factory=dict
        )
        flags: typing.Dict[typing.Literal["on", "off"], int] = serializable_field(
            default_factory=dict
        )

    # valid instances: all defaults, scalars only, plus a list, everything set
    valid_instances = [
        LiteralRoundtrip(),
        LiteralRoundtrip(mode="slow", count=2, tag="a"),
        LiteralRoundtrip(mode="fast", count=3, tag="b", items=["x", "y", "z"]),
        LiteralRoundtrip(
            mode="slow",
            count=1,
            tag=None,
            items=["z", "x"],
            scores={"win": 1, "loss": 0},
            flags={"on": 100, "off": 0},
        ),
    ]
    for original in valid_instances:
        restored = LiteralRoundtrip.load(original.serialize())
        assert restored == original
        assert restored.validate_fields_types() is True

    # each payload breaks exactly one Literal constraint and must fail to load
    invalid_payloads = [
        {"mode": "turbo", "count": 1},  # bad Literal string value
        {"mode": "fast", "count": 99},  # bad Literal int value
        {"mode": "fast", "count": 1, "items": ["x", "bad"]},  # bad list element
        {"mode": "fast", "count": 1, "scores": {"win": 99}},  # bad dict value
    ]
    for payload in invalid_payloads:
        with pytest.raises(Exception):
            LiteralRoundtrip.load(payload)

    # a bad Literal dict *key* is checked via validate_fields_types on an
    # instance assembled without __init__
    bad_key = LiteralRoundtrip.__new__(LiteralRoundtrip)
    for attr, value in (
        ("mode", "fast"),
        ("count", 1),
        ("tag", None),
        ("items", []),
        ("scores", {}),
        ("flags", {"maybe": 1}),
    ):
        object.__setattr__(bad_key, attr, value)
    assert bad_key.validate_fields_types() is False

947 

948 

# Value holder used by CustomSerializationContainer. It has no serialize/load
# of its own, so that container supplies custom serialization/loading functions.
class CustomObject:
    """Opaque wrapper around a single `value`, compared by that value."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        same_type = isinstance(other, CustomObject)
        return same_type and self.value == other.value

956 

957 

@serializable_dataclass
class CustomSerializationContainer(SerializableDataclass):
    """Test custom serialization functions"""

    # CustomObject has no serialize/load of its own: store only its `.value`
    # and rebuild the wrapper from the "custom_obj" entry on load
    custom_obj: CustomObject = serializable_field(
        serialization_fn=lambda x: x.value,
        loading_fn=lambda x: CustomObject(x["custom_obj"]),
    )
    # stored doubled; loading_fn halves it again with integer division
    transform_field: int = serializable_field(
        serialization_fn=lambda x: x * 2, loading_fn=lambda x: x["transform_field"] // 2
    )

969 

970 

def test_custom_serialization_2():
    """Test custom serialization and loading functions"""
    original = CustomSerializationContainer(
        custom_obj=CustomObject(42), transform_field=10
    )

    dumped = original.serialize()
    # custom_obj is flattened to its raw value; transform_field is doubled
    assert dumped["custom_obj"] == 42
    assert dumped["transform_field"] == 20

    restored = CustomSerializationContainer.load(dumped)
    assert restored == original
    assert restored.transform_field == 10

984 

985 

986# @pytest.mark.skip(reason="Not implemented yet, waiting on `custom_value_check_fn`") 

987# def test_value_validation(): 

988# """Test field validation""" 

989# @serializable_dataclass 

990# class ValidationContainer(SerializableDataclass): 

991# """Test validation and error handling""" 

992# positive_int: int = serializable_field( 

993# custom_value_check_fn=lambda x: x > 0 

994# ) 

995# email: str = serializable_field( 

996# custom_value_check_fn=lambda x: '@' in x 

997# ) 

998 

999# # Valid case 

1000# valid = ValidationContainer(positive_int=42, email="test@example.com") 

1001# assert valid.validate_fields_types() 

1002 

1003# # what will this do? 

1004# maybe_valid = ValidationContainer(positive_int=4.2, email="test@example.com") 

1005# assert maybe_valid.validate_fields_types() 

1006 

1007# maybe_valid_2 = ValidationContainer(positive_int=42, email=["test", "@", "example", ".com"]) 

1008# assert maybe_valid_2.validate_fields_types() 

1009 

1010# # Invalid positive_int 

1011# with pytest.raises(ValueError): 

1012# ValidationContainer(positive_int=-1, email="test@example.com") 

1013 

1014# # Invalid email 

1015# with pytest.raises(ValueError): 

1016# ValidationContainer(positive_int=42, email="invalid") 

1017 

1018 

def test_init_true_serialize_false():
    """Decorating the class below is expected to raise ``ValueError``.

    NOTE(review): the class combines several unusual field configurations:
    ``serialize=False, init=True`` on ``hidden``, ``frozen=True`` on
    ``readonly``, and ``init=False`` with no default on ``computed``.
    ``pytest.raises`` is used without ``match=``, so this test does not pin
    which of them triggers the error. The test name suggests ``hidden``
    (``init=True, serialize=False``) is the intended trigger; confirm it.
    """
    with pytest.raises(ValueError):

        @serializable_dataclass
        class MetadataContainer(SerializableDataclass):
            """Test field metadata and options"""

            hidden: str = serializable_field(serialize=False, init=True)
            readonly: int = serializable_field(init=True, frozen=True)
            computed: float = serializable_field(init=False, serialize=True)

            def __post_init__(self):
                # derives the non-init field from `readonly`; presumably never
                # runs here, since class creation is expected to fail
                object.__setattr__(self, "computed", self.readonly * 2.0)

1032 

1033 

# Test property serialization
@serializable_dataclass(properties_to_serialize=["full_name", "age_in_months"])
class PropertyContainer(SerializableDataclass):
    """Dataclass with two computed properties. Both are listed in
    ``properties_to_serialize`` so that they appear in ``serialize()``
    output."""

    first_name: str
    last_name: str
    age_years: int

    @property
    def full_name(self) -> str:
        # given name, one space, family name
        return "{} {}".format(self.first_name, self.last_name)

    @property
    def age_in_months(self) -> int:
        months_per_year = 12
        return months_per_year * self.age_years

1050 

1051 

def test_property_serialization():
    """Listed properties show up in serialize() output, and the serialized
    form still loads back into an object equal to the original."""
    person = PropertyContainer(first_name="John", last_name="Doe", age_years=30)

    payload = person.serialize()
    expected_props = {"full_name": "John Doe", "age_in_months": 360}
    for prop_name, prop_value in expected_props.items():
        assert payload[prop_name] == prop_value

    assert PropertyContainer.load(payload) == person

1062 

1063 

# TODO: this would be nice to fix, but not a massive issue
@pytest.mark.skip(reason="Not implemented yet")
def test_edge_cases():
    """Test an sdc that holds an instance of its own type, plus default and
    union-typed fields."""

    @serializable_dataclass
    class EdgeCaseContainer(SerializableDataclass):
        """Test edge cases and corner cases"""

        empty_list: List[Any] = serializable_field(default_factory=list)
        optional_value: Optional[int] = serializable_field(default=None)
        union_field: Union[str, int, None] = serializable_field(default=None)
        recursive_ref: Optional["EdgeCaseContainer"] = serializable_field(default=None)

    def roundtrip(obj):
        # serialize then load back through the class under test
        return EdgeCaseContainer.load(obj.serialize())

    # a container that references another (default) container of its own type
    outer = EdgeCaseContainer(recursive_ref=EdgeCaseContainer())
    assert roundtrip(outer) == outer

    # defaults: empty list, None for the optional and union fields
    defaults = EdgeCaseContainer()
    assert defaults.empty_list == []
    assert defaults.optional_value is None
    assert defaults.union_field is None

    # the union field must survive a round trip with each member type
    for member in ("string", 42):
        outer.union_field = member
        assert roundtrip(outer).union_field == member

1102 

1103 

# Test error handling for malformed data
def test_error_handling():
    """Malformed input: missing required fields and wrongly typed fields."""
    # loading with every required field absent
    with pytest.raises(TypeError):
        BaseClass.load({})

    # construction does not reject wrong types, but validate_fields_types() does
    mistyped = BaseClass(base_field=42, shared_field="invalid")  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
    assert not mistyped.validate_fields_types()

    # load() with both fields of the wrong type
    bad_payload = {
        "base_field": 42,  # Should be str
        "shared_field": "invalid",  # Should be int
    }
    with pytest.raises(FieldTypeMismatchError):
        BaseClass.load(bad_payload)

    # TODO: an unknown format key should probably be rejected, e.g.:
    # with pytest.raises(ValueError):
    #     BaseClass.load({
    #         _FORMAT_KEY: "InvalidClass(SerializableDataclass)",
    #         "base_field": "test",
    #         "shared_field": 0
    #     })

1129 

1130 

# Test for memory leaks and cyclic references
# TODO: make .serialize() fail on cyclic references! see https://github.com/mivanit/muutils/issues/62
@pytest.mark.skip(reason="Not implemented yet")
def test_cyclic_references():
    """Test handling of cyclic references"""

    @serializable_dataclass
    class Node(SerializableDataclass):
        value: str
        next: Optional["Node"] = serializable_field(default=None)

    # Create a cycle: node1 -> node2 -> node1
    node1 = Node(value="one")
    node2 = Node(value="two")
    node1.next = node2
    node2.next = node1

    # NOTE(review): the TODO above says .serialize() should *fail* on cyclic
    # references, but the assertions below expect serialization and loading
    # to succeed. Decide which behavior is intended before un-skipping.
    serialized = node1.serialize()
    loaded = Node.load(serialized)
    assert loaded.value == "one"
    # `next` is annotated Optional[Node]; narrow it explicitly rather than
    # silencing the type checkers with ignore comments
    assert loaded.next is not None
    assert loaded.next.value == "two"