Coverage for tests / unit / json_serialize / test_serializable_field.py: 98%

182 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1"""Tests for muutils.json_serialize.serializable_field module. 

2 

3Tests the SerializableField class and serializable_field function, 

4which extend dataclasses.Field with serialization capabilities. 

5""" 

6 

7from __future__ import annotations 

8 

9import dataclasses 

10from dataclasses import field 

11from typing import Any, Tuple 

12 

13import pytest 

14 

15from muutils.json_serialize import ( 

16 SerializableDataclass, 

17 serializable_dataclass, 

18 serializable_field, 

19) 

20from muutils.json_serialize.serializable_field import SerializableField 

21 

22 

23# ============================================================================ 

24# Test SerializableField creation with various parameters 

25# ============================================================================ 

26 

27 

def test_SerializableField_creation():
    """Exercise the SerializableField constructor across its keyword arguments.

    Covers the no-argument case, ``default``, ``default_factory``, the
    standard dataclass flags, and the custom (de)serialization hooks.
    """
    # No arguments: every hook is unset and both defaults are MISSING.
    bare = SerializableField()
    assert bare.serialize is True
    for hook_name in ("serialization_fn", "loading_fn", "deserialize_fn"):
        assert getattr(bare, hook_name) is None
    assert bare.assert_type is True
    assert bare.custom_typecheck_fn is None
    for default_attr in ("default", "default_factory"):
        assert getattr(bare, default_attr) is dataclasses.MISSING

    # Only ``default`` given: the standard flags stay switched on.
    with_default = SerializableField(default=42)
    assert with_default.default == 42
    for flag in (with_default.init, with_default.repr, with_default.compare):
        assert flag is True

    # Only ``default_factory`` given: ``default`` remains MISSING.
    with_factory = SerializableField(default_factory=list)
    assert with_factory.default_factory == list  # noqa: E721
    assert with_factory.default is dataclasses.MISSING

    # Explicit values for the standard dataclass flags plus ``serialize``.
    flagged = SerializableField(
        default=100,
        init=True,
        repr=False,
        hash=True,
        compare=False,
        serialize=True,
    )
    assert flagged.default == 100
    assert flagged.init is True
    assert flagged.repr is False
    assert flagged.hash is True
    assert flagged.compare is False
    assert flagged.serialize is True

    # Custom hooks are stored as given.
    def to_text(x):
        return str(x)

    def to_number(x):
        return int(x)

    hooked = SerializableField(
        serialization_fn=to_text,
        deserialize_fn=to_number,
        assert_type=False,
    )
    assert hooked.serialization_fn == to_text
    assert hooked.deserialize_fn == to_number
    assert hooked.assert_type is False

84 

85 

def test_SerializableField_init_serialize_validation():
    """Check that combining init=True with serialize=False raises ValueError."""
    expected_message = "Cannot have init=True and serialize=False"
    with pytest.raises(ValueError, match=expected_message):
        SerializableField(init=True, serialize=False)

90 

91 

def test_SerializableField_loading_deserialize_conflict():
    """Check that supplying loading_fn together with deserialize_fn raises ValueError."""

    def passthrough(x):
        return x

    expected_message = "Cannot pass both loading_fn and deserialize_fn"
    with pytest.raises(ValueError, match=expected_message):
        SerializableField(loading_fn=passthrough, deserialize_fn=passthrough)

102 

103 

def test_SerializableField_doc():
    """Check that the ``doc`` argument is readable back from the ``doc`` attribute."""
    doc_text = "Test documentation"
    documented = SerializableField(doc=doc_text)
    assert documented.doc == doc_text

108 

109 

110# ============================================================================ 

111# Test from_Field() method 

112# ============================================================================ 

113 

114 

def test_from_Field():
    """Check that SerializableField.from_Field converts a plain dataclasses.Field.

    Two inputs are converted: one carrying ``default`` and one carrying
    ``default_factory``.
    """
    # A stdlib field with an explicit default value.
    plain_field: dataclasses.Field[int] = field(  # type: ignore[assignment]
        default=42,  # type: ignore[arg-type]
        init=True,
        repr=True,
        hash=None,
        compare=True,
    )
    converted = SerializableField.from_Field(plain_field)

    # The standard Field attributes carry over unchanged.
    assert converted.default == 42
    assert converted.init is True
    assert converted.repr is True
    assert converted.hash is None
    assert converted.compare is True

    # The serialization-specific attributes fall back to their defaults;
    # ``serialize`` mirrors ``repr``.
    assert converted.serialize == converted.repr
    for hook_name in ("serialization_fn", "loading_fn", "deserialize_fn"):
        assert getattr(converted, hook_name) is None

    # Second input: default_factory with repr=True and init=True. Because
    # serialize follows repr, it ends up True here (asserted below).
    # Note: field() is typed to return _T (the field value type), not Field[_T],
    # to help type checkers with normal @dataclass usage. Assigning to Field[T]
    # is fine at runtime but confuses mypy's overload resolution for
    # default_factory.
    # Possibly related (closed) issue: https://github.com/python/mypy/issues/5738
    factory_field: dataclasses.Field[list[Any]] = field(  # type: ignore[assignment]
        default_factory=list,  # type: ignore[arg-type]
        repr=True,
        init=True,  # type: ignore[arg-type]
    )
    converted_factory = SerializableField.from_Field(factory_field)
    assert converted_factory.default_factory == list  # noqa: E721
    assert converted_factory.default is dataclasses.MISSING
    assert converted_factory.serialize is True  # follows repr=True

156 

157 

158# ============================================================================ 

159# Test serialization_fn and deserialize_fn 

160# ============================================================================ 

161 

162 

def test_serialization_deserialize_fn():
    """Check per-value serialization_fn and deserialize_fn hooks on a str field."""

    @serializable_dataclass
    class CustomSerialize(SerializableDataclass):
        # Upper-cased on the way out, lower-cased on the way back in.
        value: str = serializable_field(
            serialization_fn=lambda text: text.upper(),
            deserialize_fn=lambda text: text.lower(),
        )

    # Outbound: the serialized mapping holds the upper-cased text.
    outbound = CustomSerialize(value="Hello").serialize()
    assert outbound["value"] == "HELLO"

    # Inbound: the loaded instance holds the lower-cased text.
    inbound = CustomSerialize.load({"value": "WORLD"})
    assert inbound.value == "world"

182 

183 

def test_serialization_fn_with_complex_type():
    """Check hooks that convert a tuple field to a list and back."""

    @serializable_dataclass
    class ComplexSerialize(SerializableDataclass):
        # Held as a tuple in memory, written out as a list.
        coords: Tuple[int, int] = serializable_field(
            default=(0, 0),
            serialization_fn=lambda pair: list(pair),
            deserialize_fn=lambda items: tuple(items),
        )

    # Serializing yields a list.
    outbound = ComplexSerialize(coords=(3, 4)).serialize()
    assert outbound["coords"] == [3, 4]

    # Loading yields a tuple again.
    inbound = ComplexSerialize.load({"coords": [5, 6]})
    assert inbound.coords == (5, 6)

202 

203 

204# ============================================================================ 

205# Test loading_fn (takes full data dict) 

206# ============================================================================ 

207 

208 

def test_loading_fn():
    """Test loading_fn, which receives the full data dict rather than one value.

    The first part keeps the original check of a non-init, non-serialized
    field assigned after construction. The second part passes a real
    ``loading_fn`` so that the hook itself is exercised.
    """

    @serializable_dataclass
    class WithLoadingFn(SerializableDataclass):
        x: int
        y: int
        # computed field that depends on other fields
        sum_xy: int = serializable_field(
            init=False,
            serialize=False,
            default=0,
        )

    # Field excluded from __init__: it is filled in by hand after construction.
    instance = WithLoadingFn(x=3, y=4)
    instance.sum_xy = instance.x + instance.y
    assert instance.sum_xy == 7

    @serializable_dataclass
    class DerivedOnLoad(SerializableDataclass):
        x: int
        y: int
        # The hook reads sibling keys from the dict handed to load().
        total: int = serializable_field(
            default=0,
            loading_fn=lambda data: data["x"] + data["y"],
        )

    # ``total`` is present in the input with a deliberately wrong value; the
    # expected result is the value the hook derives from ``x`` and ``y``.
    loaded = DerivedOnLoad.load({"x": 3, "y": 4, "total": -1})
    assert loaded.x == 3
    assert loaded.y == 4
    assert loaded.total == 7

227 

228 

def test_loading_fn_vs_deserialize_fn():
    """Compare loading_fn (given the whole dict) with deserialize_fn (given one value).

    Both classes double the value when serializing and halve it when loading,
    so they are expected to produce the same results.
    """

    @serializable_dataclass
    class WithLoadingFn(SerializableDataclass):
        value: int = serializable_field(
            serialization_fn=lambda v: v * 2,
            loading_fn=lambda data: data["value"] // 2,  # receives the full dict
        )

    @serializable_dataclass
    class WithDeserializeFn(SerializableDataclass):
        value: int = serializable_field(
            serialization_fn=lambda v: v * 2,
            deserialize_fn=lambda v: v // 2,  # receives only the field value
        )

    # Run the same checks against each variant, in the original order.
    for variant in (WithLoadingFn, WithDeserializeFn):
        outbound = variant(value=10).serialize()
        assert outbound["value"] == 20

        inbound = variant.load({"value": 20})
        assert inbound.value == 10

260 

261 

262# ============================================================================ 

263# Test field validation: assert_type and custom_typecheck_fn 

264# ============================================================================ 

265 

266 

def test_field_validation_assert_type():
    """Check how assert_type affects loading a value of the wrong type."""

    @serializable_dataclass
    class StrictType(SerializableDataclass):
        value: int = serializable_field(assert_type=True)

    @serializable_dataclass
    class LooseType(SerializableDataclass):
        value: int = serializable_field(assert_type=False)

    wrong_payload = {"value": "not an int"}

    # assert_type=True: loading emits a "Type mismatch" UserWarning, and the
    # value is still stored as given.
    with pytest.warns(UserWarning, match="Type mismatch"):
        strict_loaded = StrictType.load(wrong_payload)
    assert strict_loaded.value == "not an int"

    # assert_type=False: the mismatched value is stored as given.
    loose_loaded = LooseType.load(wrong_payload)
    assert loose_loaded.value == "not an int"

286 

287 

def test_field_validation_custom_typecheck_fn():
    """Test that a field with a custom_typecheck_fn can be constructed.

    The check function used here accepts everything, so building an instance
    with a valid value is expected to succeed.
    """

    @serializable_dataclass
    class PositiveNumber(SerializableDataclass):
        value: int = serializable_field(
            custom_typecheck_fn=lambda t: True  # Accept any type
        )

    # This should work because custom_typecheck_fn returns True
    instance = PositiveNumber(value=42)
    assert instance.value == 42

304 

305 

306# ============================================================================ 

307# Test serializable_field() function 

308# ============================================================================ 

309 

310 

def test_serializable_field_function():
    """Check the serializable_field() factory with several argument sets."""
    # No arguments: a SerializableField with serialization enabled.
    bare = serializable_field()
    assert isinstance(bare, SerializableField)
    assert bare.serialize is True

    # ``default`` is stored on the field.
    with_default: SerializableField = serializable_field(default=100)  # type: ignore[assignment]
    assert with_default.default == 100

    # ``default_factory`` is stored on the field.
    with_factory: SerializableField = serializable_field(default_factory=list)  # type: ignore[assignment]
    assert with_factory.default_factory == list  # noqa: E721

    # Standard flags and serialization hooks supplied together.
    everything: SerializableField = serializable_field(  # type: ignore[assignment]
        default=42,
        init=True,
        repr=False,
        hash=True,
        compare=False,
        serialize=True,
        serialization_fn=str,
        deserialize_fn=int,
        assert_type=False,
    )
    assert everything.default == 42
    assert everything.repr is False
    assert everything.hash is True
    assert everything.serialization_fn == str  # noqa: E721
    assert everything.deserialize_fn == int  # noqa: E721

343 

344 

def test_serializable_field_no_positional_args():
    """Check that serializable_field rejects positional arguments with AssertionError."""
    expected_message = "unexpected positional arguments"
    with pytest.raises(AssertionError, match=expected_message):
        serializable_field("invalid")  # type: ignore

349 

350 

def test_serializable_field_description_deprecated():
    """Check that ``description`` warns as deprecated and is stored as ``doc``."""
    import warnings

    # ``description`` alone: exactly one DeprecationWarning is recorded.
    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")
        described = serializable_field(description="Test description")

        assert len(recorded) == 1
        only_warning = recorded[0]
        assert issubclass(only_warning.category, DeprecationWarning)
        assert "`description` is deprecated" in str(only_warning.message)

        # The text ends up on ``doc``.
        assert described.doc == "Test description"

    # ``doc`` and ``description`` together are rejected.
    with pytest.raises(ValueError, match="cannot pass both"):
        serializable_field(doc="Doc", description="Description")

369 

370 

371# ============================================================================ 

372# Integration tests with SerializableDataclass 

373# ============================================================================ 

374 

375 

def test_serializable_field_integration():
    """Check several kinds of serializable_field together on one SerializableDataclass.

    The class mixes a plain annotation, a field with custom hooks, a default,
    a default_factory, and a field excluded from both __init__ and
    serialization.
    """

    @serializable_dataclass
    class IntegrationTest(SerializableDataclass):
        # Plain annotated field.
        normal: str

        # Custom hooks, no default (so it precedes the fields that have one).
        custom: str = serializable_field(
            serialization_fn=lambda text: text.upper(),
            deserialize_fn=lambda text: text.lower(),
        )

        # Plain default value.
        with_default: int = serializable_field(default=42)

        # Default produced by a factory.
        with_factory: list = serializable_field(default_factory=list)

        # Kept out of __init__ and out of the serialized output.
        internal: int = serializable_field(init=False, serialize=False, default=0)

    # Build an instance, then set the non-init field by hand.
    instance = IntegrationTest(
        normal="test",
        custom="hello",
        with_default=100,
        with_factory=[1, 2, 3],
    )
    instance.internal = 999

    # Serialize: each serialized key holds the expected value.
    serialized = instance.serialize()
    expected_output = {
        "normal": "test",
        "with_default": 100,
        "with_factory": [1, 2, 3],
        "custom": "HELLO",  # upper-cased by serialization_fn
    }
    for key, expected_value in expected_output.items():
        assert serialized[key] == expected_value
    assert "internal" not in serialized  # excluded via serialize=False

    # Load: values come back, with ``custom`` lower-cased by deserialize_fn.
    payload = {
        "normal": "loaded",
        "custom": "WORLD",
        "with_default": 200,
        "with_factory": [4, 5],
    }
    loaded = IntegrationTest.load(payload)
    assert loaded.normal == "loaded"
    assert loaded.with_default == 200
    assert loaded.with_factory == [4, 5]
    assert loaded.custom == "world"
    assert loaded.internal == 0  # absent from the payload, so the default applies