Coverage for tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py: 88%

489 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1from __future__ import annotations 

2 

3from copy import deepcopy 

4import typing 

5from typing import Any, Dict, Generic, List, Optional, TypeVar, Union 

6 

7import pytest 

8 

9from muutils.errormode import ErrorMode 

10from muutils.json_serialize import ( 

11 SerializableDataclass, 

12 serializable_dataclass, 

13 serializable_field, 

14) 

15 

16from muutils.json_serialize.serializable_dataclass import ( 

17 FieldIsNotInitOrSerializeWarning, 

18 FieldTypeMismatchError, 

19) 

20from muutils.json_serialize.util import _FORMAT_KEY 

21 

22# pylint: disable=missing-class-docstring, unused-variable 

23 

24 

@serializable_dataclass
class BasicAutofields(SerializableDataclass):
    """Three plain fields, all using the default (automatic) serialization."""

    a: str
    b: int
    c: typing.List[int]

30 

31 

def test_basic_auto_fields():
    """Serialize a plain instance and check self-equality and an empty self-diff."""
    kwargs = {"a": "hello", "b": 42, "c": [1, 2, 3]}
    obj = BasicAutofields(**kwargs)  # type: ignore[arg-type]

    # serialized output is the constructor kwargs plus the format marker
    expected = {**kwargs, _FORMAT_KEY: "BasicAutofields(SerializableDataclass)"}
    assert obj.serialize() == expected

    assert obj == obj
    assert obj.diff(obj) == {}

40 

41 

def test_basic_diff():
    """`diff` lists only the differing fields, each as a {"self", "other"} pair."""
    base = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    other_a = BasicAutofields(a="goodbye", b=42, c=[1, 2, 3])
    other_b = BasicAutofields(a="hello", b=-1, c=[1, 2, 3])
    other_bc = BasicAutofields(a="hello", b=-1, c=[42])

    # the `b` change is the same in several comparisons below
    b_changed = {"self": 42, "other": -1}

    assert base.diff(other_a) == {"a": {"self": "hello", "other": "goodbye"}}
    assert base.diff(other_b) == {"b": b_changed}
    assert base.diff(other_bc) == {
        "b": b_changed,
        "c": {"self": [1, 2, 3], "other": [42]},
    }
    assert base.diff(base) == {}
    assert other_a.diff(other_b) == {
        "a": {"self": "goodbye", "other": "hello"},
        "b": b_changed,
    }

59 

60 

@serializable_dataclass
class SimpleFields(SerializableDataclass):
    """One required field plus two defaulted ones (a scalar and a factory-built list)."""

    d: str
    e: int = 42
    f: typing.List[int] = serializable_field(default_factory=list)  # noqa: F821

66 

67 

@serializable_dataclass
class FieldOptions(SerializableDataclass):
    """Exercises per-field options: an excluded field and custom (de)serializers."""

    a: str = serializable_field()
    b: str = serializable_field()
    # not an __init__ argument and not serialized; also repr=False, compare=False
    c: str = serializable_field(init=False, serialize=False, repr=False, compare=False)
    # written upper-cased; the loader is handed the whole serialized dict
    d: str = serializable_field(
        serialization_fn=lambda x: x.upper(),
        loading_fn=lambda x: x["d"].lower(),
    )

76 

77 

@serializable_dataclass(properties_to_serialize=["full_name"])
class WithProperty(SerializableDataclass):
    """Serializes the computed `full_name` property next to its stored fields."""

    first_name: str
    last_name: str

    @property
    def full_name(self) -> str:
        """First and last name separated by a single space."""
        return f"{self.first_name} {self.last_name}"

86 

87 

class Child(FieldOptions, WithProperty):
    """Plain subclass of two serializable dataclasses (not itself re-decorated)."""

90 

91 

@pytest.fixture
def simple_fields_instance():
    """A `SimpleFields` with every field given explicitly."""
    return SimpleFields(d="hello", e=42, f=[1, 2, 3])

95 

96 

@pytest.fixture
def field_options_instance():
    """A `FieldOptions`; `c` is init=False so it cannot be passed here."""
    return FieldOptions(a="hello", b="world", d="case")

100 

101 

@pytest.fixture
def with_property_instance():
    """A `WithProperty` whose `full_name` evaluates to "John Doe"."""
    return WithProperty(first_name="John", last_name="Doe")

105 

106 

def test_simple_fields_serialization(simple_fields_instance):
    """Every field, plus the format key, appears in the serialized dict."""
    expected = {
        _FORMAT_KEY: "SimpleFields(SerializableDataclass)",
        "d": "hello",
        "e": 42,
        "f": [1, 2, 3],
    }
    assert simple_fields_instance.serialize() == expected

115 

116 

def test_simple_fields_loading(simple_fields_instance):
    """A serialize -> load round trip is lossless, and the diff is empty both ways."""
    restored = SimpleFields.load(simple_fields_instance.serialize())

    assert restored == simple_fields_instance
    for left, right in (
        (restored, simple_fields_instance),
        (simple_fields_instance, restored),
    ):
        assert left.diff(right) == {}

125 

126 

def test_field_options_serialization(field_options_instance):
    """`c` is omitted and `d` goes through its upper-casing serializer."""
    result = field_options_instance.serialize()
    assert result == {
        _FORMAT_KEY: "FieldOptions(SerializableDataclass)",
        "a": "hello",
        "b": "world",
        "d": "CASE",
    }

135 

136 

def test_field_options_loading(field_options_instance):
    """Loading `FieldOptions` round-trips and emits `FieldIsNotInitOrSerializeWarning`.

    `FieldOptions.c` is declared with ``init=False, serialize=False``, which is
    the situation the warning's name describes. The test asserts that the
    warning is raised (it does not merely ignore it).
    """
    serialized = field_options_instance.serialize()
    # the warning is required: pytest.warns fails the test if it is not emitted
    with pytest.warns(FieldIsNotInitOrSerializeWarning):
        loaded = FieldOptions.load(serialized)
    assert loaded == field_options_instance
    # the custom loading_fn lower-cases the upper-cased serialized value
    assert loaded.d == "case"

143 

144 

def test_with_property_serialization(with_property_instance):
    """The `full_name` property is emitted alongside the stored fields."""
    expected = {
        _FORMAT_KEY: "WithProperty(SerializableDataclass)",
        "first_name": "John",
        "last_name": "Doe",
        "full_name": "John Doe",
    }
    assert with_property_instance.serialize() == expected

153 

154 

def test_with_property_loading(with_property_instance):
    """Loading succeeds even though the payload carries the extra `full_name` key."""
    payload = with_property_instance.serialize()
    assert WithProperty.load(payload) == with_property_instance

159 

160 

@serializable_dataclass
class Address(SerializableDataclass):
    """Flat record of string fields, used as a nested field of `Person`."""

    street: str
    city: str
    zip_code: str

166 

167 

@serializable_dataclass
class Person(SerializableDataclass):
    """Holds another serializable dataclass (`Address`) as a field."""

    name: str
    age: int
    address: Address

173 

174 

@pytest.fixture
def address_instance():
    """A fully populated `Address`."""
    return Address(street="123 Main St", city="New York", zip_code="10001")

178 

179 

@pytest.fixture
def person_instance(address_instance):
    """A `Person` wrapping the `address_instance` fixture."""
    return Person(name="John Doe", age=30, address=address_instance)

183 

184 

def test_nested_serialization(person_instance):
    """A nested dataclass serializes to its own dict, carrying its own format key."""
    address_ser = {
        _FORMAT_KEY: "Address(SerializableDataclass)",
        "street": "123 Main St",
        "city": "New York",
        "zip_code": "10001",
    }
    assert person_instance.serialize() == {
        _FORMAT_KEY: "Person(SerializableDataclass)",
        "name": "John Doe",
        "age": 30,
        "address": address_ser,
    }

199 

200 

def test_nested_loading(person_instance):
    """Loading rebuilds the nested `Address` as well as the outer `Person`."""
    roundtripped = Person.load(person_instance.serialize())
    assert roundtripped == person_instance
    assert roundtripped.address == person_instance.address

206 

207 

def test_with_printing():
    """Round-trip a class mixing a property, a custom (de)serializer and a list.

    The serialized form and the loaded instance are printed for manual
    inspection (shown with ``pytest -s``). They are also checked with
    assertions, so the test fails if the round trip breaks.
    """

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class MyClass(SerializableDataclass):
        name: str
        # stored shifted by +1 when serialized; the loader undoes the shift
        age: int = serializable_field(
            serialization_fn=lambda x: x + 1, loading_fn=lambda x: x["age"] - 1
        )
        items: list = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    my_instance = MyClass(name="John", age=30, items=["apple", "banana"])
    serialized_data = my_instance.serialize()
    print(serialized_data)

    # check the values instead of only printing them
    assert serialized_data["name"] == "John"
    assert serialized_data["age"] == 31
    assert serialized_data["items"] == ["apple", "banana"]
    assert serialized_data["full_name"] == "John Doe"

    loaded_instance = MyClass.load(serialized_data)
    print(loaded_instance)

    assert loaded_instance == my_instance
    assert loaded_instance.age == 30

228 

229 

def test_simple_class_serialization():
    """A dataclass defined inside a test function also round-trips."""

    @serializable_dataclass
    class SimpleClass(SerializableDataclass):
        a: int
        b: str

    original = SimpleClass(a=42, b="hello")
    payload = original.serialize()
    assert payload == {
        _FORMAT_KEY: "SimpleClass(SerializableDataclass)",
        "a": 42,
        "b": "hello",
    }
    assert SimpleClass.load(payload) == original

246 

247 

def test_error_when_init_and_not_serialize():
    """Declaring a field with init=True but serialize=False fails at class creation."""
    with pytest.raises(ValueError):

        @serializable_dataclass
        class SimpleClass(SerializableDataclass):
            a: int = serializable_field(init=True, serialize=False)

254 

255 

def test_person_serialization():
    """Defaults, a default_factory list and a property all survive a round trip."""

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class FullPerson(SerializableDataclass):
        name: str = serializable_field()
        age: int = serializable_field(default=-1)
        items: typing.List[str] = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    # `age` is left at its default of -1
    person = FullPerson(name="John", items=["apple", "banana"])
    actual = person.serialize()
    expected_ser = {
        _FORMAT_KEY: "FullPerson(SerializableDataclass)",
        "name": "John",
        "age": -1,
        "items": ["apple", "banana"],
        "full_name": "John Doe",
    }
    assert actual == expected_ser, f"Expected {expected_ser}, got {actual}"

    assert FullPerson.load(actual) == person

281 

282 

def test_custom_serialization():
    """A field whose serializer doubles the value and whose loader halves it."""

    @serializable_dataclass
    class CustomSerialization(SerializableDataclass):
        data: Any = serializable_field(
            serialization_fn=lambda x: x * 2, loading_fn=lambda x: x["data"] // 2
        )

    original = CustomSerialization(data=5)
    payload = original.serialize()
    assert payload == {
        _FORMAT_KEY: "CustomSerialization(SerializableDataclass)",
        "data": 10,
    }
    assert CustomSerialization.load(payload) == original

299 

300 

@serializable_dataclass
class Nested_with_Container(SerializableDataclass):
    """Holds a list of `BasicAutofields`, (de)serialized element by element."""

    val_int: int
    val_str: str
    val_list: typing.List[BasicAutofields] = serializable_field(
        default_factory=list,
        serialization_fn=lambda x: [item.serialize() for item in x],
        loading_fn=lambda x: [BasicAutofields.load(item) for item in x["val_list"]],
    )

310 

311 

def test_nested_with_container():
    """A list of nested dataclasses serializes to a list of dicts and loads back."""
    instance = Nested_with_Container(
        val_int=42,
        val_str="hello",
        val_list=[
            BasicAutofields(a="a", b=1, c=[1, 2, 3]),
            BasicAutofields(a="b", b=2, c=[4, 5, 6]),
        ],
    )

    member_fmt = "BasicAutofields(SerializableDataclass)"
    expected_ser = {
        _FORMAT_KEY: "Nested_with_Container(SerializableDataclass)",
        "val_int": 42,
        "val_str": "hello",
        "val_list": [
            {"a": "a", "b": 1, "c": [1, 2, 3], _FORMAT_KEY: member_fmt},
            {"a": "b", "b": 2, "c": [4, 5, 6], _FORMAT_KEY: member_fmt},
        ],
    }

    payload = instance.serialize()
    assert payload == expected_ser
    assert Nested_with_Container.load(payload) == instance

348 

349 

class Custom_class_with_serialization:
    """Custom class which doesn't inherit from `SerializableDataclass` but does serialize.

    It provides a `serialize` / `load` pair. `nested_custom` uses it as a
    field type.
    """

    def __init__(self, a: int, b: str):
        self.a: int = a
        self.b: str = b

    def serialize(self):
        """Return the two attributes as a plain dict."""
        return {"a": self.a, "b": self.b}

    @classmethod
    def load(cls, data):
        """Rebuild an instance from the dict produced by `serialize`."""
        return cls(data["a"], data["b"])

    def __eq__(self, other):
        # For foreign types (e.g. a raw dict or None), defer to Python's
        # fallback comparison instead of raising AttributeError on `other.a`.
        if not isinstance(other, Custom_class_with_serialization):
            return NotImplemented
        return (self.a == other.a) and (self.b == other.b)

366 

367 

@serializable_dataclass
class nested_custom(SerializableDataclass):
    """Has a field of a non-dataclass type that implements serialize/load."""

    value: float
    data1: Custom_class_with_serialization

372 

373 

def test_nested_custom(recwarn):  # `recwarn` absorbs the warnings this test triggers
    """A non-dataclass field type is serialized through its own `serialize` method."""
    original = nested_custom(
        value=42.0, data1=Custom_class_with_serialization(1, "hello")
    )
    payload = original.serialize()
    assert payload == {
        _FORMAT_KEY: "nested_custom(SerializableDataclass)",
        "value": 42.0,
        "data1": {"a": 1, "b": "hello"},
    }
    assert nested_custom.load(payload) == original

387 

388 

def test_deserialize_fn():
    """`deserialize_fn` receives the field's serialized value and restores its type."""

    @serializable_dataclass
    class DeserializeFn(SerializableDataclass):
        data: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x),
        )

    original = DeserializeFn(data=5)
    payload = original.serialize()
    assert payload == {
        _FORMAT_KEY: "DeserializeFn(SerializableDataclass)",
        "data": "5",
    }

    restored = DeserializeFn.load(payload)
    assert restored == original
    assert restored.data == 5

407 

408 

@serializable_dataclass
class DictContainer(SerializableDataclass):
    """A flat dict, a dict of dicts, and a string-valued dict (the last two default to {})."""

    simple_dict: Dict[str, int]
    nested_dict: Dict[str, Dict[str, int]] = serializable_field(default_factory=dict)
    optional_dict: Dict[str, str] = serializable_field(default_factory=dict)

416 

417 

def test_dict_serialization():
    """Dict-valued fields serialize to equal plain dicts."""
    container = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3, "z": 4}},
        optional_dict={"hello": "world"},
    )

    assert container.serialize() == {
        _FORMAT_KEY: "DictContainer(SerializableDataclass)",
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }

435 

436 

def test_dict_loading():
    """A hand-written payload loads into the expected dict fields."""
    payload = {
        _FORMAT_KEY: "DictContainer(SerializableDataclass)",
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }

    result = DictContainer.load(payload)
    assert result.simple_dict == {"a": 1, "b": 2}
    assert result.nested_dict == {"x": {"y": 3, "z": 4}}
    assert result.optional_dict == {"hello": "world"}

450 

451 

def test_dict_equality():
    """Instances compare equal exactly when their dict contents are equal."""

    def build(b_value):
        # fresh dicts on every call, so instances never share state
        return DictContainer(
            simple_dict={"a": 1, "b": b_value},
            nested_dict={"x": {"y": 3, "z": 4}},
            optional_dict={"hello": "world"},
        )

    first = build(2)
    second = build(2)
    different = build(3)

    assert first == second
    assert first != different
    assert second != different

475 

476 

def test_dict_diff():
    """`diff` reports a changed dict field as a whole before/after pair."""

    def make(simple_b=2, nested_z=4, greeting="world"):
        # each argument perturbs exactly one of the three dict fields
        return DictContainer(
            simple_dict={"a": 1, "b": simple_b},
            nested_dict={"x": {"y": 3, "z": nested_z}},
            optional_dict={"hello": greeting},
        )

    reference = make()

    assert reference.diff(make(simple_b=3)) == {
        "simple_dict": {"self": {"a": 1, "b": 2}, "other": {"a": 1, "b": 3}}
    }

    assert reference.diff(make(nested_z=5)) == {
        "nested_dict": {
            "self": {"x": {"y": 3, "z": 4}},
            "other": {"x": {"y": 3, "z": 5}},
        }
    }

    assert reference.diff(make(greeting="python")) == {
        "optional_dict": {"self": {"hello": "world"}, "other": {"hello": "python"}}
    }

    # an instance never differs from itself
    assert reference.diff(reference) == {}

529 

530 

@serializable_dataclass
class ComplexDictContainer(SerializableDataclass):
    """Dicts with heterogeneous values, list values, and three levels of nesting."""

    mixed_dict: Dict[str, Any]
    list_dict: Dict[str, typing.List[int]]
    multi_nested: Dict[str, Dict[str, Dict[str, int]]]

538 

539 

def test_complex_dict_serialization():
    """Mixed, list-valued and deeply nested dicts all round-trip."""
    original = ComplexDictContainer(
        mixed_dict={"str": "hello", "int": 42, "list": [1, 2, 3]},
        list_dict={"a": [1, 2, 3], "b": [4, 5, 6]},
        multi_nested={"x": {"y": {"z": 1, "w": 2}, "v": {"u": 3, "t": 4}}},
    )

    restored = ComplexDictContainer.load(original.serialize())
    assert restored == original
    assert restored.diff(original) == {}

552 

553 

def test_empty_dicts():
    """Empty dicts round-trip, and two all-empty instances compare equal."""
    blank = DictContainer(simple_dict={}, nested_dict={}, optional_dict={})

    restored = DictContainer.load(blank.serialize())
    assert restored == blank
    assert restored.diff(blank) == {}

    assert blank == DictContainer(simple_dict={}, nested_dict={}, optional_dict={})

566 

567 

# Fixture class for dictionary value-type validation
@serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
class StrictDictContainer(SerializableDataclass):
    """Dict fields with fixed value types; a type mismatch is configured to raise."""

    int_dict: Dict[str, int]
    str_dict: Dict[str, str]
    float_dict: Dict[str, float]

576 

577 

# TODO: per-value type checking of dict fields does not seem to be enforced
# yet; remove the skip once it is.
@pytest.mark.skip(reason="dict type validation doesnt seem to work")
def test_dict_type_validation():
    """Dict values of the wrong type should be rejected at construction."""
    well_typed = StrictDictContainer(
        int_dict={"a": 1, "b": 2},
        str_dict={"x": "hello", "y": "world"},
        float_dict={"m": 1.0, "n": 2.5},
    )
    assert well_typed.validate_fields_types()

    # a str where int_dict expects an int
    with pytest.raises(FieldTypeMismatchError):
        StrictDictContainer(
            int_dict={"a": "not an int"},  # type: ignore[dict-item]
            str_dict={"x": "hello"},
            float_dict={"m": 1.0},
        )

    # an int where str_dict expects a str
    with pytest.raises(FieldTypeMismatchError):
        StrictDictContainer(
            int_dict={"a": 1},
            str_dict={"x": 123},  # type: ignore[dict-item]
            float_dict={"m": 1.0},
        )

606 

# Fixture class: dicts whose values, or the dict itself, may be None
@serializable_dataclass
class OptionalDictContainer(SerializableDataclass):
    """Optional/union dict values, plus a whole dict that may be None."""

    optional_values: Dict[str, Optional[int]]
    union_values: Dict[str, Union[int, str]]
    nullable_dict: Optional[Dict[str, int]] = None

615 

616 

def test_optional_dict_values():
    """Optional and union dict values, and a None dict, all round-trip."""
    cases = [
        OptionalDictContainer(
            optional_values={"a": 1, "b": None, "c": 3},
            union_values={"x": 1, "y": "string", "z": 42},
            nullable_dict={"m": 1, "n": 2},
        ),
        # every optional value None, and the nullable dict itself None
        OptionalDictContainer(
            optional_values={"a": None, "b": None},
            union_values={"x": "all strings", "y": "here"},
            nullable_dict=None,
        ),
    ]

    for case in cases:
        assert OptionalDictContainer.load(case.serialize()) == case

639 

640 

# In-place mutation of dict fields
def test_dict_mutation():
    """Mutating an instance's dicts leaves a deep copy untouched and shows up in diff."""
    original = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3}},
        optional_dict={"hello": "world"},
    )
    snapshot = deepcopy(original)

    # mutate every dict on the original in place
    original.simple_dict["c"] = 3
    original.nested_dict["x"]["z"] = 4
    original.optional_dict["new"] = "value"

    # the deep copy must still hold the original contents
    assert snapshot.simple_dict == {"a": 1, "b": 2}
    assert snapshot.nested_dict == {"x": {"y": 3}}
    assert snapshot.optional_dict == {"hello": "world"}

    changed = snapshot.diff(original)
    for field_name in ("simple_dict", "nested_dict", "optional_dict"):
        assert field_name in changed

667 

668 

# Fixture class for non-string dictionary keys
@serializable_dataclass
class IntKeyDictContainer(SerializableDataclass):
    """Int-keyed dict, with keys stringified on serialize and parsed back on load."""

    int_keys: Dict[int, str] = serializable_field(
        serialization_fn=lambda x: {str(key): val for key, val in x.items()},
        loading_fn=lambda x: {int(key): val for key, val in x["int_keys"].items()},
    )

678 

679 

def test_non_string_dict_keys():
    """Int keys become strings when serialized and ints again when loaded."""
    original = IntKeyDictContainer(int_keys={1: "one", 2: "two", 3: "three"})

    payload = original.serialize()
    assert all(isinstance(key, str) for key in payload["int_keys"].keys())

    restored = IntKeyDictContainer.load(payload)
    assert all(isinstance(key, int) for key in restored.int_keys.keys())
    assert restored == original

692 

693 

@serializable_dataclass
class RecursiveDictContainer(SerializableDataclass):
    """Single `Dict[str, Any]` field, used to hold arbitrarily nested data."""

    data: Dict[str, Any]

699 

700 

def test_recursive_dict_structure():
    """Dicts nested several levels deep, including one inside a list, round-trip."""
    innermost = {"value": 42, "list": [1, 2, {"nested": "value"}]}
    deep_dict = {"level1": {"level2": {"level3": innermost}}}

    original = RecursiveDictContainer(data=deep_dict)
    restored = RecursiveDictContainer.load(original.serialize())

    assert restored == original
    assert restored.data == deep_dict

715 

716 

# Defined at module level rather than inside the test that uses it. Because of
# `from __future__ import annotations`, field annotations are strings that are
# presumably resolved against module globals, where a class local to a test
# function would not be visible. NOTE(review): confirm against the validator.
class CustomSerializable:
    """Small value wrapper exposing `serialize` / `load` methods."""

    def __init__(self, value):
        self.value: Union[str, int] = value

    def serialize(self):
        """Wrap the value in a one-key dict."""
        return {"value": self.value}

    @classmethod
    def load(cls, data):
        """Inverse of `serialize`."""
        return cls(data["value"])

    def __eq__(self, other):
        if not isinstance(other, CustomSerializable):
            return False
        return self.value == other.value

731 

732 

def test_dict_with_custom_objects():
    """A dict whose values are custom serialize/load objects round-trips."""

    @serializable_dataclass
    class CustomObjectDict(SerializableDataclass):
        data: Dict[str, CustomSerializable] = serializable_field()

    original = CustomObjectDict(
        data={"a": CustomSerializable(42), "b": CustomSerializable("hello")}
    )

    # construction keeps the objects as-is (no eager conversion)
    assert isinstance(original, CustomObjectDict)
    assert isinstance(original.data, dict)
    for key, value_type in (("a", int), ("b", str)):
        assert isinstance(original.data[key], CustomSerializable)
        assert isinstance(original.data[key].value, value_type)

    assert CustomObjectDict.load(original.serialize()) == original

754 

755 

def test_empty_optional_dicts():
    """None and {} stay distinct through a round trip; factory defaults stay {}."""

    @serializable_dataclass
    class OptionalDictFields(SerializableDataclass):
        required_dict: Dict[str, int]
        optional_dict: Optional[Dict[str, int]] = None
        default_empty: Dict[str, int] = serializable_field(default_factory=dict)

    with_none = OptionalDictFields(required_dict={"a": 1}, optional_dict=None)
    with_empty = OptionalDictFields(required_dict={"a": 1}, optional_dict={})

    reloaded_none, reloaded_empty = (
        OptionalDictFields.load(obj.serialize()) for obj in (with_none, with_empty)
    )

    assert reloaded_none.optional_dict is None
    assert reloaded_empty.optional_dict == {}
    assert reloaded_none.default_empty == {}
    assert reloaded_empty.default_empty == {}

781 

782 

# Inheritance hierarchy: BaseClass <- ChildClass <- GrandchildClass
@serializable_dataclass(
    on_typecheck_error=ErrorMode.EXCEPT, on_typecheck_mismatch=ErrorMode.EXCEPT
)
class BaseClass(SerializableDataclass):
    """Root of the inheritance tests; configured to raise on type-check problems."""

    base_field: str
    shared_field: int = serializable_field(default=0)

792 

793 

@serializable_dataclass
class ChildClass(BaseClass):
    """Adds `child_field` and re-declares `shared_field` with a new default."""

    child_field: float = serializable_field(default=0.1)
    # redeclared: default becomes 1 instead of BaseClass's 0
    shared_field: int = serializable_field(default=1)

800 

801 

@serializable_dataclass
class GrandchildClass(ChildClass):
    """Third level of the hierarchy, adding a boolean field."""

    grandchild_field: bool = serializable_field(default=True)

807 

808 

def test_inheritance():
    """Fields from every level of a three-deep hierarchy serialize and load."""
    grandchild = GrandchildClass(
        base_field="base", shared_field=42, child_field=3.14, grandchild_field=True
    )

    data = grandchild.serialize()
    for key, expected in (
        ("base_field", "base"),
        ("shared_field", 42),
        ("child_field", 3.14),
    ):
        assert data[key] == expected
    assert data["grandchild_field"] is True

    assert GrandchildClass.load(data) == grandchild

    # loading through the base class yields a BaseClass, not one of its subclasses
    parent = BaseClass.load({"base_field": "test", "shared_field": 1})
    assert isinstance(parent, BaseClass)
    assert not isinstance(parent, ChildClass)

828 

829 

@pytest.mark.skip(
    reason="Not implemented yet, generic types not supported and throw a `TypeHintNotImplementedError`"
)
def test_generic_types():
    """A Generic[T] dataclass should round-trip for each concrete parameterization."""

    T = TypeVar("T")

    @serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
    class GenericContainer(SerializableDataclass, Generic[T]):
        """Test generic type parameters"""

        value: T
        values: List[T]

    # parameterized with int
    ints = GenericContainer[int](value=42, values=[1, 2, 3])
    assert GenericContainer[int].load(ints.serialize()) == ints

    # parameterized with str
    strs = GenericContainer[str](value="hello", values=["a", "b", "c"])
    assert GenericContainer[str].load(strs.serialize()) == strs

856 

857 

# Value type without serialization hooks, for custom serialization/deserialization
class CustomObject:
    """Opaque value holder; `CustomSerializationContainer` supplies its (de)serializers."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value if isinstance(other, CustomObject) else False

865 

866 

@serializable_dataclass
class CustomSerializationContainer(SerializableDataclass):
    """Fields whose (de)serialization is fully user-supplied."""

    # serialized as the bare wrapped value, re-wrapped on load
    custom_obj: CustomObject = serializable_field(
        serialization_fn=lambda x: x.value,
        loading_fn=lambda x: CustomObject(x["custom_obj"]),
    )
    # doubled on serialize, halved (integer division) on load
    transform_field: int = serializable_field(
        serialization_fn=lambda x: x * 2,
        loading_fn=lambda x: x["transform_field"] // 2,
    )

877 ) 

878 

879 

def test_custom_serialization_2():
    """Custom serialization/loading functions are applied in both directions."""
    original = CustomSerializationContainer(
        custom_obj=CustomObject(42), transform_field=10
    )

    payload = original.serialize()
    assert payload["custom_obj"] == 42
    assert payload["transform_field"] == 20

    restored = CustomSerializationContainer.load(payload)
    assert restored == original
    assert restored.transform_field == 10

894 

895# @pytest.mark.skip(reason="Not implemented yet, waiting on `custom_value_check_fn`") 

896# def test_value_validation(): 

897# """Test field validation""" 

898# @serializable_dataclass 

899# class ValidationContainer(SerializableDataclass): 

900# """Test validation and error handling""" 

901# positive_int: int = serializable_field( 

902# custom_value_check_fn=lambda x: x > 0 

903# ) 

904# email: str = serializable_field( 

905# custom_value_check_fn=lambda x: '@' in x 

906# ) 

907 

908# # Valid case 

909# valid = ValidationContainer(positive_int=42, email="test@example.com") 

910# assert valid.validate_fields_types() 

911 

912# # what will this do? 

913# maybe_valid = ValidationContainer(positive_int=4.2, email="test@example.com") 

914# assert maybe_valid.validate_fields_types() 

915 

916# maybe_valid_2 = ValidationContainer(positive_int=42, email=["test", "@", "example", ".com"]) 

917# assert maybe_valid_2.validate_fields_types() 

918 

919# # Invalid positive_int 

920# with pytest.raises(ValueError): 

921# ValidationContainer(positive_int=-1, email="test@example.com") 

922 

923# # Invalid email 

924# with pytest.raises(ValueError): 

925# ValidationContainer(positive_int=42, email="invalid") 

926 

927 

def test_init_true_serialize_false():
    """A class with an init=True, serialize=False field is rejected with ValueError."""
    with pytest.raises(ValueError):

        @serializable_dataclass
        class MetadataContainer(SerializableDataclass):
            """Test field metadata and options"""

            # this combination is what should trigger the ValueError
            hidden: str = serializable_field(serialize=False, init=True)
            readonly: int = serializable_field(init=True, frozen=True)
            computed: float = serializable_field(init=False, serialize=True)

            def __post_init__(self):
                object.__setattr__(self, "computed", self.readonly * 2.0)

941 

942 

# Fixture class with two serialized properties
@serializable_dataclass(properties_to_serialize=["full_name", "age_in_months"])
class PropertyContainer(SerializableDataclass):
    """Serializes two computed properties in addition to its stored fields."""

    first_name: str
    last_name: str
    age_years: int

    @property
    def full_name(self) -> str:
        """First and last name separated by a space."""
        return f"{self.first_name} {self.last_name}"

    @property
    def age_in_months(self) -> int:
        """`age_years` converted to months."""
        return self.age_years * 12

959 

960 

def test_property_serialization():
    """Both properties are emitted, and the extra keys do not break loading."""
    original = PropertyContainer(first_name="John", last_name="Doe", age_years=30)

    payload = original.serialize()
    assert payload["full_name"] == "John Doe"
    assert payload["age_in_months"] == 360

    assert PropertyContainer.load(payload) == original

971 

972 

# TODO: this would be nice to fix, but not a massive issue
@pytest.mark.skip(reason="Not implemented yet")
def test_edge_cases():
    """A dataclass holding an instance of its own type, plus defaults and unions."""

    @serializable_dataclass
    class EdgeCaseContainer(SerializableDataclass):
        """Test edge cases and corner cases"""

        empty_list: List[Any] = serializable_field(default_factory=list)
        optional_value: Optional[int] = serializable_field(default=None)
        union_field: Union[str, int, None] = serializable_field(default=None)
        recursive_ref: Optional["EdgeCaseContainer"] = serializable_field(default=None)

    # one level of self-reference
    inner = EdgeCaseContainer()
    outer = EdgeCaseContainer(recursive_ref=inner)
    assert EdgeCaseContainer.load(outer.serialize()) == outer

    # defaults: empty list, None, None
    defaults = EdgeCaseContainer()
    assert defaults.empty_list == []
    assert defaults.optional_value is None
    assert defaults.union_field is None

    # each member type of the union survives a round trip
    outer.union_field = "string"
    assert EdgeCaseContainer.load(outer.serialize()).union_field == "string"

    outer.union_field = 42
    assert EdgeCaseContainer.load(outer.serialize()).union_field == 42

1011 

1012 

# Error handling for malformed data
def test_error_handling():
    """Missing or wrongly typed fields are reported rather than silently accepted."""
    # an empty payload lacks the required `base_field`
    with pytest.raises(TypeError):
        BaseClass.load({})

    # direct construction does not raise, but validation reports the mismatch
    mistyped = BaseClass(base_field=42, shared_field="invalid")  # type: ignore[arg-type]
    assert not mistyped.validate_fields_types()

    # loading the same mistyped values raises (BaseClass uses ErrorMode.EXCEPT)
    with pytest.raises(FieldTypeMismatchError):
        BaseClass.load(
            {
                "base_field": 42,  # Should be str
                "shared_field": "invalid",  # Should be int
            }
        )

    # Invalid format string
    # with pytest.raises(ValueError):
    #     BaseClass.load({
    #         _FORMAT_KEY: "InvalidClass(SerializableDataclass)",
    #         "base_field": "test",
    #         "shared_field": 0
    #     })

1038 

1039 

# Cyclic references between instances
# TODO: make .serialize() fail on cyclic references! see https://github.com/mivanit/muutils/issues/62
@pytest.mark.skip(reason="Not implemented yet")
def test_cyclic_references():
    """Serialize two nodes that point at each other (skipped until cycles are handled)."""

    @serializable_dataclass
    class Node(SerializableDataclass):
        value: str
        next: Optional["Node"] = serializable_field(default=None)

    # close a two-node cycle
    first, second = Node("one"), Node("two")
    first.next = second
    second.next = first

    # Ensure we can serialize without infinite recursion
    loaded = Node.load(first.serialize())
    assert loaded.value == "one"
    # `next` is declared Optional["Node"]; without a None check a type checker
    # flags the attribute access, hence the ignore.
    assert loaded.next.value == "two"  # type: ignore[union-attr]