Coverage for tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py: 99%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1from typing import Any 

2import typing 

3 

4import pytest 

5 

6# Import the decorator and base class. 

7# (Adjust the import below to match your project structure.) 

8from muutils.json_serialize import ( 

9 serializable_dataclass, 

10 SerializableDataclass, 

11) 

12from muutils.json_serialize.util import _FORMAT_KEY 

13 

14 

# Minimal two-field serializable dataclass used by the module-level
# round-trip test below.
@serializable_dataclass
class SimpleClass(SerializableDataclass):
    a: int  # integer field
    b: str  # string field

19 

20 

def test_simple_class_serialization():
    """Round-trip a `SimpleClass` instance through `serialize` and `load`."""
    original = SimpleClass(a=42, b="hello")

    # The serialized mapping holds both fields plus the format marker.
    expected_payload = {
        "a": 42,
        "b": "hello",
        _FORMAT_KEY: "SimpleClass(SerializableDataclass)",
    }
    payload = original.serialize()
    assert payload == expected_payload

    # Loading the payload must give back an object equal to the original.
    restored = SimpleClass.load(payload)
    assert restored == original

32 

33 

def test_default_overrides():
    """Check the decorator-supplied `serialize`, `load`, `__eq__` and `validate_fields_types` when nothing is excluded."""

    @serializable_dataclass
    class DefaultClass(SerializableDataclass):
        a: int
        b: str

    reference = DefaultClass(a=1, b="test")

    # The serialized form carries the format marker and every field value.
    payload: typing.Dict[str, Any] = reference.serialize()
    expected_entries = (
        (_FORMAT_KEY, "DefaultClass(SerializableDataclass)"),
        ("a", 1),
        ("b", "test"),
    )
    for key, value in expected_entries:
        assert payload.get(key) == value

    # `load` rebuilds an instance that compares equal to the source object.
    restored = DefaultClass.load(payload)
    assert restored == reference

    # Correctly typed fields pass validation.
    assert reference.validate_fields_types() is True

    # Equality follows field values: same values match, different ones do not.
    same_values = DefaultClass(a=1, b="test")
    assert reference == same_values
    different_a = DefaultClass(a=2, b="test")
    assert reference != different_a

63 

64 

def test_no_override_serialize():
    """A user-written `serialize` is kept when it is listed in `methods_no_override`."""

    @serializable_dataclass(methods_no_override=["serialize"], register_handler=False)
    class NoSerializeClass(SerializableDataclass):
        a: int

        def serialize(self) -> typing.Dict[str, Any]:
            # Hand-rolled output that leaves out the _FORMAT_KEY entry.
            return {"custom": self.a}

    instance = NoSerializeClass(a=42)

    # The class's own serialize must be the one that runs.
    payload: typing.Dict[str, Any] = instance.serialize()
    assert payload == {"custom": 42}

    # `load` is still available and builds an instance from the field mapping.
    restored = NoSerializeClass.load({"a": 42})
    assert restored.a == 42

    # Equality still works between two instances holding the same value.
    twin = NoSerializeClass(a=42)
    assert instance == twin

89 

90 

def test_no_override_eq_and_serialize():
    """Listing both '__eq__' and 'serialize' in `methods_no_override` keeps the
    user-defined versions, while `load` and `validate_fields_types` stay usable."""

    @serializable_dataclass(
        methods_no_override=["__eq__", "serialize"], register_handler=False
    )
    class NoEqSerializeClass(SerializableDataclass):
        a: int

        def __eq__(self, other: Any) -> bool:
            # Equal only to another NoEqSerializeClass with the same `a`.
            return isinstance(other, NoEqSerializeClass) and self.a == other.a

        def serialize(self) -> typing.Dict[str, Any]:
            return {"custom_serialize": self.a}

    first = NoEqSerializeClass(a=100)
    same_as_first = NoEqSerializeClass(a=100)
    other_value = NoEqSerializeClass(a=200)

    # The hand-written serialize is the one in effect.
    assert first.serialize() == {"custom_serialize": 100}

    # The hand-written __eq__ decides equality.
    assert first == same_as_first
    assert first != other_value

    # `load` still constructs an instance from a field mapping.
    restored = NoEqSerializeClass.load({"a": 100})
    assert restored.a == 100

    # Field-type validation is still available and passes.
    assert first.validate_fields_types() is True

127 

128 

def test_inheritance_override():
    """Exercise `methods_no_override` across a small class hierarchy.

    - The base class keeps its own `serialize`.
    - A subclass decorated without `methods_no_override` serializes its
      fields together with the format marker.
    - A subclass decorated with `methods_no_override=["serialize"]` keeps
      using the `serialize` inherited from the base class.
    """

    @serializable_dataclass(methods_no_override=["serialize"], register_handler=False)
    class BaseClass(SerializableDataclass):
        a: int

        def serialize(self) -> typing.Dict[str, Any]:
            return {"base": self.a}

    # No exclusions here, so `serialize` is not protected on this subclass.
    @serializable_dataclass(register_handler=False)
    class SubClass(BaseClass):
        b: int

    # `serialize` is excluded again, so the inherited method stays in place.
    @serializable_dataclass(methods_no_override=["serialize"], register_handler=False)
    class SubClassPreserve(BaseClass):
        b: int

    # Base class: custom output.
    parent = BaseClass(a=10)
    assert parent.serialize() == {"base": 10}

    # Plain subclass: format marker and both field values are present.
    child = SubClass(a=1, b=2)
    child_payload: typing.Dict[str, Any] = child.serialize()
    assert _FORMAT_KEY in child_payload
    assert child_payload.get("a") == 1
    assert child_payload.get("b") == 2

    # Preserving subclass: output comes from BaseClass.serialize.
    preserving_child = SubClassPreserve(a=20, b=30)
    assert preserving_child.serialize() == {"base": 20}

168 

169 

def test_polymorphic_behavior():
    """Two unrelated classes can serialize differently depending on `methods_no_override`."""

    @serializable_dataclass(methods_no_override=["serialize"], register_handler=False)
    class PolyA(SerializableDataclass):
        a: int

        def serialize(self) -> typing.Dict[str, Any]:
            return {"poly_a": self.a}

    @serializable_dataclass(register_handler=False)
    class PolyB(SerializableDataclass):
        b: int

    custom_instance = PolyA(a=5)
    default_instance = PolyB(b=15)

    # PolyA answers with its own serialize.
    assert custom_instance.serialize() == {"poly_a": 5}

    # PolyB's output carries the format marker and its field value.
    default_payload: typing.Dict[str, Any] = default_instance.serialize()
    assert default_payload.get(_FORMAT_KEY) == "PolyB(SerializableDataclass)"
    assert default_payload.get("b") == 15

    # `load` works for both classes regardless of how they serialize.
    custom_restored = PolyA.load({"a": 5})
    default_restored = PolyB.load({"b": 15})
    assert custom_restored.a == 5
    assert default_restored.b == 15

200 

201 

def test_unknown_methods_warning():
    """Passing an unrecognized name in `methods_no_override` emits a UserWarning."""
    expected_message = "Unknown methods in `methods_no_override`"

    # The warning is expected while the class is being decorated.
    with pytest.warns(UserWarning, match=expected_message):

        @serializable_dataclass(
            methods_no_override=["non_existing_method"], register_handler=False
        )
        class UnknownMethodClass(SerializableDataclass):
            a: int

    # The class is still usable after the warning.
    instance = UnknownMethodClass(a=999)
    payload: typing.Dict[str, Any] = instance.serialize()

    # "serialize" was not excluded, so the output has the marker and the field.
    assert _FORMAT_KEY in payload
    assert payload.get("a") == 999

217 assert ser.get("a") == 999