Coverage for tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py: 99%
113 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1from typing import Any
2import typing
4import pytest
6# Import the decorator and base class.
7# (Adjust the import below to match your project structure.)
8from muutils.json_serialize import (
9 serializable_dataclass,
10 SerializableDataclass,
11)
12from muutils.json_serialize.util import _FORMAT_KEY
@serializable_dataclass
class SimpleClass(SerializableDataclass):
    """Minimal two-field dataclass for a basic serialize/load round trip.

    Decorated with the decorator's default options (no `methods_no_override`).
    """

    a: int  # integer payload field
    b: str  # string payload field
def test_simple_class_serialization():
    """Round-trip a module-level `SimpleClass` through serialize() and load()."""
    original = SimpleClass(a=42, b="hello")

    # Expected payload: the field values plus the format marker.
    expected = {"a": 42, "b": "hello"}
    expected[_FORMAT_KEY] = "SimpleClass(SerializableDataclass)"

    as_dict = original.serialize()
    assert as_dict == expected

    # load() must rebuild an instance equal to the one we started with.
    assert SimpleClass.load(as_dict) == original
def test_default_overrides():
    """With no `methods_no_override`, the decorator supplies __eq__, serialize, load and validate_fields_types."""

    @serializable_dataclass
    class DefaultClass(SerializableDataclass):
        a: int
        b: str

    instance: DefaultClass = DefaultClass(a=1, b="test")

    # serialize(): the payload carries the format marker and every field value.
    payload: typing.Dict[str, Any] = instance.serialize()
    assert payload.get(_FORMAT_KEY) == "DefaultClass(SerializableDataclass)"
    assert payload.get("a") == 1
    assert payload.get("b") == "test"

    # load(): rebuilding from the payload yields an equal object (equality via dc_eq).
    restored: DefaultClass = DefaultClass.load(payload)
    assert restored == instance

    # validate_fields_types(): every field holds a value of its annotated type.
    assert instance.validate_fields_types() is True

    # __eq__: identical field values compare equal; a differing value does not.
    twin: DefaultClass = DefaultClass(a=1, b="test")
    different: DefaultClass = DefaultClass(a=2, b="test")
    assert instance == twin
    assert instance != different
def test_no_override_serialize():
    """`methods_no_override=["serialize"]` keeps the user-defined serialize().

    Methods not listed (load, __eq__) are still supplied by the decorator.
    """

    @serializable_dataclass(methods_no_override=["serialize"], register_handler=False)
    class NoSerializeClass(SerializableDataclass):
        a: int

        def serialize(self) -> typing.Dict[str, Any]:
            # Hand-written output that deliberately bypasses the _FORMAT_KEY mechanism.
            return {"custom": self.a}

    first: NoSerializeClass = NoSerializeClass(a=42)

    # The hand-written serialize() survives decoration.
    assert first.serialize() == {"custom": 42}

    # load() comes from the decorator and builds an instance from the field values.
    rebuilt: NoSerializeClass = NoSerializeClass.load({"a": 42})
    assert rebuilt.a == 42

    # __eq__ also comes from the decorator (dc_eq), so equal fields compare equal.
    second: NoSerializeClass = NoSerializeClass(a=42)
    assert first == second
def test_no_override_eq_and_serialize():
    """Test that specifying both '__eq__' and 'serialize' in methods_no_override preserves
    the user-defined methods, while load and validate_fields_types are still overridden.

    The class has a second field ``b`` that the custom ``__eq__`` ignores. With only one
    field, a custom "compare ``a``" equality is indistinguishable from the decorator's
    field-wise (dc_eq) equality, so the test could not tell which one is installed.
    Instances that differ only in ``b`` make the difference observable.
    """

    @serializable_dataclass(
        methods_no_override=["__eq__", "serialize"], register_handler=False
    )
    class NoEqSerializeClass(SerializableDataclass):
        a: int
        b: str

        def __eq__(self, other: Any) -> bool:
            if not isinstance(other, NoEqSerializeClass):
                return False
            # Custom equality: only compare the 'a' attribute, ignoring 'b'.
            return self.a == other.a

        def serialize(self) -> typing.Dict[str, Any]:
            return {"custom_serialize": self.a}

    obj1: NoEqSerializeClass = NoEqSerializeClass(a=100, b="x")
    obj2: NoEqSerializeClass = NoEqSerializeClass(a=100, b="x")
    obj_other_b: NoEqSerializeClass = NoEqSerializeClass(a=100, b="y")
    obj3: NoEqSerializeClass = NoEqSerializeClass(a=200, b="x")

    # Check that the custom serialize is used.
    assert obj1.serialize() == {"custom_serialize": 100}

    # Check that the custom __eq__ is used.
    assert obj1 == obj2
    # Holds only under the custom __eq__; equality over all fields would compare 'b' too.
    assert obj1 == obj_other_b
    assert obj1 != obj3

    # The load method should be the decorator's version.
    loaded: NoEqSerializeClass = NoEqSerializeClass.load({"a": 100, "b": "x"})
    assert loaded.a == 100
    assert loaded.b == "x"

    # validate_fields_types should be provided by the decorator.
    assert obj1.validate_fields_types() is True
def test_inheritance_override():
    """Check how `methods_no_override` interacts with inheritance.

    - A base class keeps its hand-written serialize (listed in methods_no_override).
    - A subclass that does not list serialize gets the decorator's version.
    - A subclass that does list serialize keeps the one inherited from the base.
    """

    @serializable_dataclass(methods_no_override=["serialize"], register_handler=False)
    class BaseClass(SerializableDataclass):
        a: int

        def serialize(self) -> typing.Dict[str, Any]:
            return {"base": self.a}

    # Not preserving serialize: the decorator replaces the inherited one.
    @serializable_dataclass(register_handler=False)
    class SubClass(BaseClass):
        b: int

    # Preserving serialize: BaseClass.serialize stays in effect.
    @serializable_dataclass(methods_no_override=["serialize"], register_handler=False)
    class SubClassPreserve(BaseClass):
        b: int

    # Base: hand-written serialize is kept.
    assert BaseClass(a=10).serialize() == {"base": 10}

    # SubClass: decorator output, i.e. the format marker plus all field values.
    sub_payload: typing.Dict[str, Any] = SubClass(a=1, b=2).serialize()
    assert _FORMAT_KEY in sub_payload
    assert sub_payload.get("a") == 1
    assert sub_payload.get("b") == 2

    # SubClassPreserve: the inherited hand-written serialize (ignores 'b').
    preserved: SubClassPreserve = SubClassPreserve(a=20, b=30)
    assert preserved.serialize() == {"base": 20}
def test_polymorphic_behavior():
    """Test that polymorphic classes can use different serialize implementations based on methods_no_override.

    PolyA keeps a hand-written serialize; PolyB uses the decorator's. Both still get
    the decorator-provided load and __eq__, so loaded instances must equal the originals.
    """

    @serializable_dataclass(methods_no_override=["serialize"], register_handler=False)
    class PolyA(SerializableDataclass):
        a: int

        def serialize(self) -> typing.Dict[str, Any]:
            return {"poly_a": self.a}

    @serializable_dataclass(register_handler=False)
    class PolyB(SerializableDataclass):
        b: int

    a_obj: PolyA = PolyA(a=5)
    b_obj: PolyB = PolyB(b=15)

    # PolyA uses its custom serialize.
    assert a_obj.serialize() == {"poly_a": 5}

    # PolyB uses the default decorator-provided serialize.
    ser_b: typing.Dict[str, Any] = b_obj.serialize()
    assert ser_b.get(_FORMAT_KEY) == "PolyB(SerializableDataclass)"
    assert ser_b.get("b") == 15

    # Equality and load should work polymorphically.
    a_loaded: PolyA = PolyA.load({"a": 5})
    b_loaded: PolyB = PolyB.load({"b": 15})
    assert a_loaded.a == 5
    assert b_loaded.b == 15
    # The original asserted only field values; also check the decorator's __eq__
    # on the loaded instances, as the comment above promises.
    assert a_loaded == a_obj
    assert b_loaded == b_obj
def test_unknown_methods_warning():
    """Passing an unrecognized name in methods_no_override emits a UserWarning.

    The decorated class must still be fully functional afterwards.
    """
    expected_message = "Unknown methods in `methods_no_override`"
    with pytest.warns(UserWarning, match=expected_message):

        @serializable_dataclass(
            methods_no_override=["non_existing_method"], register_handler=False
        )
        class UnknownMethodClass(SerializableDataclass):
            a: int

    # Despite the warning, the class behaves normally. "serialize" was not
    # listed, so the decorator's serialize (with the format marker) is in effect.
    instance: UnknownMethodClass = UnknownMethodClass(a=999)
    payload: typing.Dict[str, Any] = instance.serialize()
    assert _FORMAT_KEY in payload
    assert payload.get("a") == 999