Coverage for tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py: 88%
489 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1from __future__ import annotations
3from copy import deepcopy
4import typing
5from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
7import pytest
9from muutils.errormode import ErrorMode
10from muutils.json_serialize import (
11 SerializableDataclass,
12 serializable_dataclass,
13 serializable_field,
14)
16from muutils.json_serialize.serializable_dataclass import (
17 FieldIsNotInitOrSerializeWarning,
18 FieldTypeMismatchError,
19)
20from muutils.json_serialize.util import _FORMAT_KEY
22# pylint: disable=missing-class-docstring, unused-variable
@serializable_dataclass
class BasicAutofields(SerializableDataclass):
    """Minimal dataclass whose fields are all detected automatically.

    No field uses `serializable_field`; every attribute is a plain annotation.
    """

    a: str
    b: int
    c: typing.List[int]
def test_basic_auto_fields():
    """Auto-detected fields serialize with a format key and survive a round trip."""
    data = dict(a="hello", b=42, c=[1, 2, 3])
    instance = BasicAutofields(**data)  # type: ignore[arg-type]
    data_with_format = data.copy()
    data_with_format[_FORMAT_KEY] = "BasicAutofields(SerializableDataclass)"
    assert instance.serialize() == data_with_format
    assert instance == instance
    assert instance.diff(instance) == {}

    # reflexive equality above is trivially true; also check that a
    # serialize -> load round trip reproduces an equal, diff-free instance
    loaded = BasicAutofields.load(instance.serialize())
    assert loaded == instance
    assert loaded.diff(instance) == {}
def test_basic_diff():
    """`diff` reports exactly the differing fields as {"self": ..., "other": ...}."""
    base = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    changed_a = BasicAutofields(a="goodbye", b=42, c=[1, 2, 3])
    changed_b = BasicAutofields(a="hello", b=-1, c=[1, 2, 3])
    changed_bc = BasicAutofields(a="hello", b=-1, c=[42])

    assert base.diff(changed_a) == {"a": {"self": "hello", "other": "goodbye"}}
    assert base.diff(changed_b) == {"b": {"self": 42, "other": -1}}

    expected_bc = {
        "b": {"self": 42, "other": -1},
        "c": {"self": [1, 2, 3], "other": [42]},
    }
    assert base.diff(changed_bc) == expected_bc
    assert base.diff(base) == {}

    expected_ab = {
        "a": {"self": "goodbye", "other": "hello"},
        "b": {"self": 42, "other": -1},
    }
    assert changed_a.diff(changed_b) == expected_ab
@serializable_dataclass
class SimpleFields(SerializableDataclass):
    """One required field, one plain default and one factory default."""

    d: str
    e: int = 42
    # a mutable default needs a factory so instances don't share one list
    f: typing.List[int] = serializable_field(default_factory=list)  # noqa: F821
@serializable_dataclass
class FieldOptions(SerializableDataclass):
    """Exercises the per-field options accepted by `serializable_field`."""

    a: str = serializable_field()
    b: str = serializable_field()
    # excluded from __init__, serialization, repr and comparison
    c: str = serializable_field(init=False, serialize=False, repr=False, compare=False)
    # upper-cased when serialized; the loader receives the whole serialized dict
    d: str = serializable_field(
        serialization_fn=lambda value: value.upper(),
        loading_fn=lambda payload: payload["d"].lower(),
    )
@serializable_dataclass(properties_to_serialize=["full_name"])
class WithProperty(SerializableDataclass):
    """Serializes the computed `full_name` property next to its fields."""

    first_name: str
    last_name: str

    @property
    def full_name(self) -> str:
        """First and last name joined by a single space."""
        return "{} {}".format(self.first_name, self.last_name)
class Child(FieldOptions, WithProperty):
    """Undecorated subclass combining two serializable dataclasses."""
@pytest.fixture
def simple_fields_instance():
    """A `SimpleFields` with every field set explicitly."""
    instance = SimpleFields(d="hello", e=42, f=[1, 2, 3])
    return instance
@pytest.fixture
def field_options_instance():
    """A `FieldOptions`; `c` is not an init field, so it is not passed."""
    instance = FieldOptions(a="hello", b="world", d="case")
    return instance
@pytest.fixture
def with_property_instance():
    """A `WithProperty` whose `full_name` resolves to "John Doe"."""
    instance = WithProperty(first_name="John", last_name="Doe")
    return instance
def test_simple_fields_serialization(simple_fields_instance):
    """Defaulted fields serialize alongside the required one."""
    expected = dict(d="hello", e=42, f=[1, 2, 3])
    expected[_FORMAT_KEY] = "SimpleFields(SerializableDataclass)"
    assert simple_fields_instance.serialize() == expected
def test_simple_fields_loading(simple_fields_instance):
    """A serialize -> load round trip gives an equal instance with no diff either way."""
    loaded = SimpleFields.load(simple_fields_instance.serialize())

    assert loaded == simple_fields_instance
    for left, right in ((loaded, simple_fields_instance), (simple_fields_instance, loaded)):
        assert left.diff(right) == {}
def test_field_options_serialization(field_options_instance):
    """`c` is omitted and `d` passes through its custom serialization_fn."""
    expected = dict(a="hello", b="world", d="CASE")
    expected[_FORMAT_KEY] = "FieldOptions(SerializableDataclass)"
    assert field_options_instance.serialize() == expected
def test_field_options_loading(field_options_instance):
    """Loading `FieldOptions` emits a warning but still round-trips to an equal instance."""
    serialized = field_options_instance.serialize()
    # `pytest.warns` *requires* a `FieldIsNotInitOrSerializeWarning` here (it
    # does not merely ignore it) -- presumably triggered by field `c`, which is
    # declared with init=False and serialize=False
    with pytest.warns(FieldIsNotInitOrSerializeWarning):
        loaded = FieldOptions.load(serialized)
    assert loaded == field_options_instance
def test_with_property_serialization(with_property_instance):
    """Properties listed in `properties_to_serialize` appear in the output."""
    expected = dict(first_name="John", last_name="Doe", full_name="John Doe")
    expected[_FORMAT_KEY] = "WithProperty(SerializableDataclass)"
    assert with_property_instance.serialize() == expected
def test_with_property_loading(with_property_instance):
    """A serialized property value does not get in the way of loading."""
    payload = with_property_instance.serialize()
    assert WithProperty.load(payload) == with_property_instance
@serializable_dataclass
class Address(SerializableDataclass):
    """Leaf dataclass nested inside `Person`."""

    street: str
    city: str
    zip_code: str
@serializable_dataclass
class Person(SerializableDataclass):
    """Dataclass with a nested serializable dataclass field."""

    name: str
    age: int
    address: Address
@pytest.fixture
def address_instance():
    """A fully populated `Address`."""
    instance = Address(street="123 Main St", city="New York", zip_code="10001")
    return instance
@pytest.fixture
def person_instance(address_instance):
    """A `Person` that embeds the `address_instance` fixture."""
    instance = Person(name="John Doe", age=30, address=address_instance)
    return instance
def test_nested_serialization(person_instance):
    """A nested dataclass serializes to a nested dict with its own format key."""
    address_ser = {
        "street": "123 Main St",
        "city": "New York",
        "zip_code": "10001",
        _FORMAT_KEY: "Address(SerializableDataclass)",
    }
    expected_ser = {
        "name": "John Doe",
        "age": 30,
        "address": address_ser,
        _FORMAT_KEY: "Person(SerializableDataclass)",
    }
    assert person_instance.serialize() == expected_ser
def test_nested_loading(person_instance):
    """Loading rebuilds the nested `Address` as well as the outer `Person`."""
    restored = Person.load(person_instance.serialize())
    assert restored == person_instance
    assert restored.address == person_instance.address
def test_with_printing():
    """Custom (de)serialization functions and a serialized property round-trip.

    The original only printed the serialized data and the loaded instance,
    so it could never fail. It now asserts on both.
    """

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class MyClass(SerializableDataclass):
        name: str
        # stored as age + 1 when serialized, and decremented again on load
        age: int = serializable_field(
            serialization_fn=lambda x: x + 1, loading_fn=lambda x: x["age"] - 1
        )
        items: list = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    my_instance = MyClass(name="John", age=30, items=["apple", "banana"])
    serialized_data = my_instance.serialize()
    assert serialized_data == {
        "name": "John",
        "age": 31,
        "items": ["apple", "banana"],
        "full_name": "John Doe",
        _FORMAT_KEY: "MyClass(SerializableDataclass)",
    }

    loaded_instance = MyClass.load(serialized_data)
    assert loaded_instance == my_instance
    assert loaded_instance.age == 30
def test_simple_class_serialization():
    """A dataclass defined inside a test round-trips through serialize/load."""

    @serializable_dataclass
    class SimpleClass(SerializableDataclass):
        a: int
        b: str

    original = SimpleClass(a=42, b="hello")
    payload = original.serialize()

    expected = dict(a=42, b="hello")
    expected[_FORMAT_KEY] = "SimpleClass(SerializableDataclass)"
    assert payload == expected
    assert SimpleClass.load(payload) == original
def test_error_when_init_and_not_serialize():
    """A field that is init=True but serialize=False is rejected at class creation."""
    # the error is expected while the decorator processes the class body
    with pytest.raises(ValueError):

        @serializable_dataclass
        class SimpleClass(SerializableDataclass):
            a: int = serializable_field(init=True, serialize=False)
def test_person_serialization():
    """Defaults, factory defaults and a serialized property all round-trip."""

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class FullPerson(SerializableDataclass):
        name: str = serializable_field()
        age: int = serializable_field(default=-1)
        items: typing.List[str] = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    # `age` is omitted on purpose so its default (-1) is exercised
    person = FullPerson(name="John", items=["apple", "banana"])
    serialized = person.serialize()

    expected_ser = dict(
        name="John",
        age=-1,
        items=["apple", "banana"],
        full_name="John Doe",
    )
    expected_ser[_FORMAT_KEY] = "FullPerson(SerializableDataclass)"
    assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}"

    assert FullPerson.load(serialized) == person
def test_custom_serialization():
    """`serialization_fn` and `loading_fn` are applied on the way out and back in."""

    @serializable_dataclass
    class CustomSerialization(SerializableDataclass):
        data: Any = serializable_field(
            serialization_fn=lambda value: value * 2,
            loading_fn=lambda payload: payload["data"] // 2,
        )

    custom = CustomSerialization(data=5)
    serialized = custom.serialize()

    expected = {"data": 10, _FORMAT_KEY: "CustomSerialization(SerializableDataclass)"}
    assert serialized == expected
    assert CustomSerialization.load(serialized) == custom
@serializable_dataclass
class Nested_with_Container(SerializableDataclass):
    """Holds a list of `BasicAutofields`, (de)serialized element by element."""

    val_int: int
    val_str: str
    val_list: typing.List[BasicAutofields] = serializable_field(
        default_factory=list,
        serialization_fn=lambda items: [item.serialize() for item in items],
        loading_fn=lambda payload: [
            BasicAutofields.load(item) for item in payload["val_list"]
        ],
    )
def test_nested_with_container():
    """A list of nested dataclasses serializes element-wise and loads back."""
    members = [
        BasicAutofields(a="a", b=1, c=[1, 2, 3]),
        BasicAutofields(a="b", b=2, c=[4, 5, 6]),
    ]
    instance = Nested_with_Container(val_int=42, val_str="hello", val_list=members)

    member_format = "BasicAutofields(SerializableDataclass)"
    expected_members = [
        {"a": "a", "b": 1, "c": [1, 2, 3], _FORMAT_KEY: member_format},
        {"a": "b", "b": 2, "c": [4, 5, 6], _FORMAT_KEY: member_format},
    ]
    expected_ser = {
        "val_int": 42,
        "val_str": "hello",
        "val_list": expected_members,
        _FORMAT_KEY: "Nested_with_Container(SerializableDataclass)",
    }

    serialized = instance.serialize()
    assert serialized == expected_ser
    assert Nested_with_Container.load(serialized) == instance
class Custom_class_with_serialization:
    """custom class which doesnt inherit but does serialize

    Implements the `serialize`/`load` pair by hand instead of subclassing
    `SerializableDataclass`.
    """

    def __init__(self, a: int, b: str):
        self.a: int = a
        self.b: str = b

    def serialize(self):
        """Return a plain dict holding `a` and `b`."""
        return {"a": self.a, "b": self.b}

    @classmethod
    def load(cls, data):
        """Rebuild an instance from the dict produced by `serialize`."""
        return cls(data["a"], data["b"])

    def __eq__(self, other):
        # For foreign types, return NotImplemented so Python can try the
        # reflected comparison. Previously this raised AttributeError when
        # `other` lacked `.a` (e.g. `obj == 5`).
        if not isinstance(other, Custom_class_with_serialization):
            return NotImplemented
        return (self.a == other.a) and (self.b == other.b)
@serializable_dataclass
class nested_custom(SerializableDataclass):
    """Field typed as a plain (non-dataclass) class that implements serialize/load."""

    value: float
    data1: Custom_class_with_serialization
def test_nested_custom(recwarn):  # `recwarn` records the warnings this test emits
    """A field holding a custom serialize/load class round-trips."""
    custom_value = Custom_class_with_serialization(1, "hello")
    instance = nested_custom(value=42.0, data1=custom_value)

    serialized = instance.serialize()
    expected_ser = dict(value=42.0, data1={"a": 1, "b": "hello"})
    expected_ser[_FORMAT_KEY] = "nested_custom(SerializableDataclass)"
    assert serialized == expected_ser

    assert nested_custom.load(serialized) == instance
def test_deserialize_fn():
    """`deserialize_fn` converts the stored value back to the field's type."""

    @serializable_dataclass
    class DeserializeFn(SerializableDataclass):
        # stored as a string, parsed back to int on load
        data: int = serializable_field(
            serialization_fn=lambda number: str(number),
            deserialize_fn=lambda text: int(text),
        )

    instance = DeserializeFn(data=5)
    serialized = instance.serialize()
    assert serialized == {"data": "5", _FORMAT_KEY: "DeserializeFn(SerializableDataclass)"}

    restored = DeserializeFn.load(serialized)
    assert restored == instance
    assert restored.data == 5
@serializable_dataclass
class DictContainer(SerializableDataclass):
    """One required dict field and two dict fields that default to empty."""

    simple_dict: Dict[str, int]
    nested_dict: Dict[str, Dict[str, int]] = serializable_field(default_factory=dict)
    optional_dict: Dict[str, str] = serializable_field(default_factory=dict)
def test_dict_serialization():
    """Dict-valued fields serialize as plain (nested) dicts."""
    fields = dict(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3, "z": 4}},
        optional_dict={"hello": "world"},
    )
    # snapshot the expected output before serializing, so it is independent
    # of the dict objects handed to the instance
    expected = {_FORMAT_KEY: "DictContainer(SerializableDataclass)", **deepcopy(fields)}

    container = DictContainer(**fields)
    assert container.serialize() == expected
def test_dict_loading():
    """Dict-valued fields load back from their serialized form unchanged."""
    expected_fields = {
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }
    payload = {
        _FORMAT_KEY: "DictContainer(SerializableDataclass)",
        **deepcopy(expected_fields),
    }

    loaded = DictContainer.load(payload)
    for field_name, expected in expected_fields.items():
        assert getattr(loaded, field_name) == expected
def test_dict_equality():
    """Equality compares dict fields by value."""

    def build(b_value):
        # a fresh set of dicts on every call, so instances share nothing
        return DictContainer(
            simple_dict={"a": 1, "b": b_value},
            nested_dict={"x": {"y": 3, "z": 4}},
            optional_dict={"hello": "world"},
        )

    first, second, differing = build(2), build(2), build(3)

    assert first == second
    assert first != differing
    assert second != differing
def test_dict_diff():
    """`diff` reports whole dict values for the field that changed."""

    def make(**overrides):
        # fresh dicts per call; `overrides` replaces individual fields
        fields = dict(
            simple_dict={"a": 1, "b": 2},
            nested_dict={"x": {"y": 3, "z": 4}},
            optional_dict={"hello": "world"},
        )
        fields.update(overrides)
        return DictContainer(**fields)

    reference = make()
    simple_changed = make(simple_dict={"a": 1, "b": 3})
    nested_changed = make(nested_dict={"x": {"y": 3, "z": 5}})
    optional_changed = make(optional_dict={"hello": "python"})

    assert reference.diff(simple_changed) == {
        "simple_dict": {"self": {"a": 1, "b": 2}, "other": {"a": 1, "b": 3}}
    }
    assert reference.diff(nested_changed) == {
        "nested_dict": {
            "self": {"x": {"y": 3, "z": 4}},
            "other": {"x": {"y": 3, "z": 5}},
        }
    }
    assert reference.diff(optional_changed) == {
        "optional_dict": {"self": {"hello": "world"}, "other": {"hello": "python"}}
    }
    # an instance never differs from itself
    assert reference.diff(reference) == {}
@serializable_dataclass
class ComplexDictContainer(SerializableDataclass):
    """Dicts with Any values, list values and three levels of nesting."""

    mixed_dict: Dict[str, Any]
    list_dict: Dict[str, typing.List[int]]
    multi_nested: Dict[str, Dict[str, Dict[str, int]]]
def test_complex_dict_serialization():
    """Mixed, list-valued and deeply nested dicts survive a round trip."""
    original = ComplexDictContainer(
        mixed_dict={"str": "hello", "int": 42, "list": [1, 2, 3]},
        list_dict={"a": [1, 2, 3], "b": [4, 5, 6]},
        multi_nested={"x": {"y": {"z": 1, "w": 2}, "v": {"u": 3, "t": 4}}},
    )

    restored = ComplexDictContainer.load(original.serialize())
    assert restored == original
    assert restored.diff(original) == {}
def test_empty_dicts():
    """Empty dicts round-trip and compare equal to another empty instance."""

    def empty():
        return DictContainer(simple_dict={}, nested_dict={}, optional_dict={})

    data = empty()
    restored = DictContainer.load(data.serialize())
    assert restored == data
    assert restored.diff(data) == {}

    assert data == empty()
# Dict value types checked strictly (used by the type-validation test)
@serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
class StrictDictContainer(SerializableDataclass):
    """Dict fields with concrete value types; configured to raise on type mismatch."""

    int_dict: Dict[str, int]
    str_dict: Dict[str, str]
    float_dict: Dict[str, float]
# TODO: figure this out
@pytest.mark.skip(reason="dict type validation doesnt seem to work")
def test_dict_type_validation():
    """Dict *value* types should be validated on construction (currently skipped)."""
    valid = StrictDictContainer(
        int_dict={"a": 1, "b": 2},
        str_dict={"x": "hello", "y": "world"},
        float_dict={"m": 1.0, "n": 2.5},
    )
    assert valid.validate_fields_types()

    # each entry puts a wrongly typed value into exactly one dict
    invalid_kwargs = [
        dict(int_dict={"a": "not an int"}, str_dict={"x": "hello"}, float_dict={"m": 1.0}),
        dict(int_dict={"a": 1}, str_dict={"x": 123}, float_dict={"m": 1.0}),
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(FieldTypeMismatchError):
            StrictDictContainer(**kwargs)
# Dictionaries whose values (or the dict itself) may be None / a union
@serializable_dataclass
class OptionalDictContainer(SerializableDataclass):
    """Optional and Union dict values, plus a dict field that defaults to None."""

    optional_values: Dict[str, Optional[int]]
    union_values: Dict[str, Union[int, str]]
    nullable_dict: Optional[Dict[str, int]] = None
def test_optional_dict_values():
    """Optional/Union dict values and a None dict field all round-trip."""
    cases = [
        OptionalDictContainer(
            optional_values={"a": 1, "b": None, "c": 3},
            union_values={"x": 1, "y": "string", "z": 42},
            nullable_dict={"m": 1, "n": 2},
        ),
        # the whole optional dict field set to None
        OptionalDictContainer(
            optional_values={"a": None, "b": None},
            union_values={"x": "all strings", "y": "here"},
            nullable_dict=None,
        ),
    ]
    for case in cases:
        assert OptionalDictContainer.load(case.serialize()) == case
# Mutating dict fields in place
def test_dict_mutation():
    """Mutations leave a deep copy untouched, and `diff` detects them."""
    mutated = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3}},
        optional_dict={"hello": "world"},
    )
    snapshot = deepcopy(mutated)

    mutated.simple_dict["c"] = 3
    mutated.nested_dict["x"]["z"] = 4
    mutated.optional_dict["new"] = "value"

    # the deep copy must not share any (nested) dict with the mutated instance
    assert snapshot.simple_dict == {"a": 1, "b": 2}
    assert snapshot.nested_dict == {"x": {"y": 3}}
    assert snapshot.optional_dict == {"hello": "world"}

    changes = snapshot.diff(mutated)
    for field_name in ("simple_dict", "nested_dict", "optional_dict"):
        assert field_name in changes
# Non-string dictionary keys
@serializable_dataclass
class IntKeyDictContainer(SerializableDataclass):
    """Int-keyed dict: keys are stringified when serialized and parsed back on load."""

    int_keys: Dict[int, str] = serializable_field(
        serialization_fn=lambda mapping: {str(key): val for key, val in mapping.items()},
        loading_fn=lambda payload: {
            int(key): val for key, val in payload["int_keys"].items()
        },
    )
def test_non_string_dict_keys():
    """Int keys become strings when serialized and ints again after loading."""
    instance = IntKeyDictContainer(int_keys={1: "one", 2: "two", 3: "three"})

    serialized = instance.serialize()
    assert all(isinstance(key, str) for key in serialized["int_keys"])

    loaded = IntKeyDictContainer.load(serialized)
    assert all(isinstance(key, int) for key in loaded.int_keys)
    assert loaded == instance
@serializable_dataclass
class RecursiveDictContainer(SerializableDataclass):
    """Arbitrarily nested dict stored under a `Dict[str, Any]` annotation."""

    data: Dict[str, Any]
def test_recursive_dict_structure():
    """Deeply nested dicts, including dicts inside lists, round-trip intact."""
    innermost = {"value": 42, "list": [1, 2, {"nested": "value"}]}
    deep_dict = {"level1": {"level2": {"level3": innermost}}}

    instance = RecursiveDictContainer(data=deep_dict)
    restored = RecursiveDictContainer.load(instance.serialize())

    assert restored == instance
    assert restored.data == deep_dict
# Defined at module level rather than inside the test that uses it. With
# `from __future__ import annotations`, field annotations are strings, and the
# validator presumably resolves them against module globals, where a
# function-local class would not be visible -- TODO confirm.
class CustomSerializable:
    """Non-dataclass value type implementing the `serialize`/`load` protocol."""

    def __init__(self, value):
        self.value: Union[str, int] = value

    def serialize(self):
        """Return a plain dict holding `value`."""
        return {"value": self.value}

    @classmethod
    def load(cls, data):
        """Rebuild an instance from the dict produced by `serialize`."""
        return cls(data["value"])

    def __eq__(self, other):
        # return NotImplemented (not False) for foreign types so Python can
        # fall back to the other operand's __eq__
        if not isinstance(other, CustomSerializable):
            return NotImplemented
        return self.value == other.value
def test_dict_with_custom_objects():
    """Dict values that are custom serialize/load objects survive a round trip."""

    @serializable_dataclass
    class CustomObjectDict(SerializableDataclass):
        data: Dict[str, CustomSerializable] = serializable_field()

    instance: CustomObjectDict = CustomObjectDict(
        data={"a": CustomSerializable(42), "b": CustomSerializable("hello")}
    )

    assert isinstance(instance, CustomObjectDict)
    assert isinstance(instance.data, dict)
    for key, value_type in (("a", int), ("b", str)):
        entry = instance.data[key]
        assert isinstance(entry, CustomSerializable)
        assert isinstance(entry.value, value_type)

    assert CustomObjectDict.load(instance.serialize()) == instance
def test_empty_optional_dicts():
    """None and {} stay distinct in an Optional dict field after loading."""

    @serializable_dataclass
    class OptionalDictFields(SerializableDataclass):
        required_dict: Dict[str, int]
        optional_dict: Optional[Dict[str, int]] = None
        default_empty: Dict[str, int] = serializable_field(default_factory=dict)

    with_none = OptionalDictFields(required_dict={"a": 1}, optional_dict=None)
    with_empty = OptionalDictFields(required_dict={"a": 1}, optional_dict={})

    payloads = [with_none.serialize(), with_empty.serialize()]
    loaded_none, loaded_empty = [OptionalDictFields.load(p) for p in payloads]

    assert loaded_none.optional_dict is None
    assert loaded_empty.optional_dict == {}
    assert loaded_none.default_empty == {}
    assert loaded_empty.default_empty == {}
# Three-level inheritance hierarchy: BaseClass -> ChildClass -> GrandchildClass
@serializable_dataclass(
    on_typecheck_error=ErrorMode.EXCEPT, on_typecheck_mismatch=ErrorMode.EXCEPT
)
class BaseClass(SerializableDataclass):
    """Root of the hierarchy; configured to raise on type-check errors and mismatches."""

    base_field: str
    shared_field: int = serializable_field(default=0)
@serializable_dataclass
class ChildClass(BaseClass):
    """Adds `child_field` and redeclares `shared_field` with a different default."""

    child_field: float = serializable_field(default=0.1)
    shared_field: int = serializable_field(default=1)  # overrides BaseClass's default
@serializable_dataclass
class GrandchildClass(ChildClass):
    """Third level of the hierarchy, adding a boolean field."""

    grandchild_field: bool = serializable_field(default=True)
def test_inheritance():
    """Fields from every level of the hierarchy serialize and load back."""
    instance = GrandchildClass(
        base_field="base", shared_field=42, child_field=3.14, grandchild_field=True
    )

    serialized = instance.serialize()
    for key, expected in (("base_field", "base"), ("shared_field", 42), ("child_field", 3.14)):
        assert serialized[key] == expected
    assert serialized["grandchild_field"] is True

    assert GrandchildClass.load(serialized) == instance

    # the parent class can still be loaded on its own
    base_loaded = BaseClass.load({"base_field": "test", "shared_field": 1})
    assert isinstance(base_loaded, BaseClass)
    assert not isinstance(base_loaded, ChildClass)
@pytest.mark.skip(
    reason="Not implemented yet, generic types not supported and throw a `TypeHintNotImplementedError`"
)
def test_generic_types():
    """Generic dataclasses should round-trip per parametrization (currently skipped)."""

    T = TypeVar("T")

    @serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
    class GenericContainer(SerializableDataclass, Generic[T]):
        """Test generic type parameters"""

        value: T
        values: List[T]

    int_box = GenericContainer[int](value=42, values=[1, 2, 3])
    assert GenericContainer[int].load(int_box.serialize()) == int_box

    str_box = GenericContainer[str](value="hello", values=["a", "b", "c"])
    assert GenericContainer[str].load(str_box.serialize()) == str_box
# Test custom serialization/deserialization
class CustomObject:
    """Plain value holder with no serialize/load methods of its own.

    `CustomSerializationContainer` handles its conversion via field-level
    serialization/loading functions.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # return NotImplemented (not False) for foreign types so Python can
        # fall back to the other operand's __eq__
        if not isinstance(other, CustomObject):
            return NotImplemented
        return self.value == other.value
@serializable_dataclass
class CustomSerializationContainer(SerializableDataclass):
    """Fields converted entirely by field-level serialization/loading functions."""

    # stored as its bare `.value`, rewrapped on load
    custom_obj: CustomObject = serializable_field(
        serialization_fn=lambda obj: obj.value,
        loading_fn=lambda payload: CustomObject(payload["custom_obj"]),
    )
    # stored doubled, halved again on load
    transform_field: int = serializable_field(
        serialization_fn=lambda number: number * 2,
        loading_fn=lambda payload: payload["transform_field"] // 2,
    )
def test_custom_serialization_2():
    """Field-level serialization and loading functions round-trip correctly."""
    original = CustomSerializationContainer(custom_obj=CustomObject(42), transform_field=10)

    payload = original.serialize()
    assert payload["custom_obj"] == 42
    assert payload["transform_field"] == 20

    restored = CustomSerializationContainer.load(payload)
    assert restored == original
    assert restored.transform_field == 10
895# @pytest.mark.skip(reason="Not implemented yet, waiting on `custom_value_check_fn`")
896# def test_value_validation():
897# """Test field validation"""
898# @serializable_dataclass
899# class ValidationContainer(SerializableDataclass):
900# """Test validation and error handling"""
901# positive_int: int = serializable_field(
902# custom_value_check_fn=lambda x: x > 0
903# )
904# email: str = serializable_field(
905# custom_value_check_fn=lambda x: '@' in x
906# )
908# # Valid case
909# valid = ValidationContainer(positive_int=42, email="test@example.com")
910# assert valid.validate_fields_types()
912# # what will this do?
913# maybe_valid = ValidationContainer(positive_int=4.2, email="test@example.com")
914# assert maybe_valid.validate_fields_types()
916# maybe_valid_2 = ValidationContainer(positive_int=42, email=["test", "@", "example", ".com"])
917# assert maybe_valid_2.validate_fields_types()
919# # Invalid positive_int
920# with pytest.raises(ValueError):
921# ValidationContainer(positive_int=-1, email="test@example.com")
923# # Invalid email
924# with pytest.raises(ValueError):
925# ValidationContainer(positive_int=42, email="invalid")
def test_init_true_serialize_false():
    """`hidden` (init=True, serialize=False) makes class creation raise ValueError."""
    with pytest.raises(ValueError):

        @serializable_dataclass
        class MetadataContainer(SerializableDataclass):
            """Test field metadata and options"""

            hidden: str = serializable_field(serialize=False, init=True)
            readonly: int = serializable_field(init=True, frozen=True)
            computed: float = serializable_field(init=False, serialize=True)

            def __post_init__(self):
                # bypasses any __setattr__ override (e.g. on frozen dataclasses)
                object.__setattr__(self, "computed", self.readonly * 2.0)
# Serializing multiple computed properties
@serializable_dataclass(properties_to_serialize=["full_name", "age_in_months"])
class PropertyContainer(SerializableDataclass):
    """Two computed properties are serialized alongside the stored fields."""

    first_name: str
    last_name: str
    age_years: int

    @property
    def full_name(self) -> str:
        """First and last name separated by a space."""
        return "{} {}".format(self.first_name, self.last_name)

    @property
    def age_in_months(self) -> int:
        """Age in years multiplied by 12."""
        return self.age_years * 12
def test_property_serialization():
    """Computed properties are serialized, and loading still works with them present."""
    person = PropertyContainer(first_name="John", last_name="Doe", age_years=30)

    payload = person.serialize()
    for key, expected in (("full_name", "John Doe"), ("age_in_months", 360)):
        assert payload[key] == expected

    assert PropertyContainer.load(payload) == person
# TODO: this would be nice to fix, but not a massive issue
@pytest.mark.skip(reason="Not implemented yet")
def test_edge_cases():
    """A dataclass holding an instance of itself, plus None/union defaults (skipped)."""

    @serializable_dataclass
    class EdgeCaseContainer(SerializableDataclass):
        """Test edge cases and corner cases"""

        empty_list: List[Any] = serializable_field(default_factory=list)
        optional_value: Optional[int] = serializable_field(default=None)
        union_field: Union[str, int, None] = serializable_field(default=None)
        recursive_ref: Optional["EdgeCaseContainer"] = serializable_field(default=None)

    # self-referential structure, one level deep
    outer = EdgeCaseContainer(recursive_ref=EdgeCaseContainer())
    assert EdgeCaseContainer.load(outer.serialize()) == outer

    # defaults
    blank = EdgeCaseContainer()
    assert blank.empty_list == []
    assert blank.optional_value is None
    assert blank.union_field is None

    # each union member survives a round trip
    for union_value in ("string", 42):
        outer.union_field = union_value
        reloaded = EdgeCaseContainer.load(outer.serialize())
        assert reloaded.union_field == union_value
# Error handling for malformed data
def test_error_handling():
    """Missing required fields and mistyped values are rejected."""
    # empty payload: required `base_field` is absent
    with pytest.raises(TypeError):
        BaseClass.load({})

    # direct construction doesn't raise here, but validation reports the mismatch
    mistyped = BaseClass(base_field=42, shared_field="invalid")  # type: ignore[arg-type]
    assert not mistyped.validate_fields_types()

    bad_payload = {
        "base_field": 42,  # Should be str
        "shared_field": "invalid",  # Should be int
    }
    with pytest.raises(FieldTypeMismatchError):
        BaseClass.load(bad_payload)

    # Invalid format string
    # with pytest.raises(ValueError):
    #     BaseClass.load({
    #         _FORMAT_KEY: "InvalidClass(SerializableDataclass)",
    #         "base_field": "test",
    #         "shared_field": 0
    #     })
# Test for memory leaks and cyclic references
# TODO: make .serialize() fail on cyclic references! see https://github.com/mivanit/muutils/issues/62
@pytest.mark.skip(reason="Not implemented yet")
def test_cyclic_references():
    """Test handling of cyclic references.

    NOTE(review): the TODO above asks for `serialize()` to *fail* on cycles,
    while this body expects it to succeed. Reconcile the two once issue #62
    is resolved.
    """

    @serializable_dataclass
    class Node(SerializableDataclass):
        value: str
        next: Optional["Node"] = serializable_field(default=None)

    # build a two-node cycle
    node1 = Node("one")
    node2 = Node("two")
    node1.next = node2
    node2.next = node1

    serialized = node1.serialize()
    loaded = Node.load(serialized)
    assert loaded.value == "one"
    # `next` is Optional[Node]: narrow it explicitly instead of silencing mypy
    # with `# type: ignore[union-attr]`
    assert loaded.next is not None
    assert loaded.next.value == "two"
1061 # TODO: idk why we type ignore here
1062 assert loaded.next.value == "two" # type: ignore[union-attr]