Coverage for tests/unit/test_dictmagic.py: 100%

131 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1from __future__ import annotations 

2from typing import Dict 

3 

4import pytest 

5 

6from muutils.dictmagic import ( 

7 condense_nested_dicts, 

8 condense_nested_dicts_matching_values, 

9 condense_tensor_dict, 

10 dotlist_to_nested_dict, 

11 is_numeric_consecutive, 

12 kwargs_to_nested_dict, 

13 nested_dict_to_dotlist, 

14 tuple_dims_replace, 

15 update_with_nested_dict, 

16) 

17from muutils.json_serialize import SerializableDataclass, serializable_dataclass 

18 

19 

def test_dotlist_to_nested_dict():
    """`dotlist_to_nested_dict` expands dotted keys into nested dicts.

    Covers the default "." separator, a non-string key (TypeError), and an
    explicit custom separator.
    """
    expected = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # default "." separator
    assert dotlist_to_nested_dict({"a.b.c": 1, "a.b.d": 2, "a.e": 3}) == expected

    # non-string keys are rejected
    with pytest.raises(TypeError):
        dotlist_to_nested_dict({1: 1})  # type: ignore[dict-item]

    # explicit "/" separator yields the same structure
    slash_separated = {"a/b/c": 1, "a/b/d": 2, "a/e": 3}
    assert dotlist_to_nested_dict(slash_separated, sep="/") == expected

34 

35 

def test_update_with_nested_dict():
    """`update_with_nested_dict` merges an update into a nested original dict.

    Every case is built from its own dict literals, so no case can observe
    another case's inputs, even if the function were to modify them.
    """
    cases = [
        # a nested value under an existing key is overridden
        ({"a": {"b": 1}, "c": -1}, {"a": {"b": 2}}, {"a": {"b": 2}, "c": -1}),
        # a key absent from the original is added
        ({"a": {"b": 1}, "c": -1}, {"d": 3}, {"a": {"b": 1}, "c": -1, "d": 3}),
        # one nested value is overridden while its sibling is preserved
        (
            {"a": {"b": 1, "d": 3}, "c": -1},
            {"a": {"b": 2}},
            {"a": {"b": 2, "d": 3}, "c": -1},
        ),
        # a nested dict missing from the original is created
        ({"a": 1}, {"b": {"c": 2}}, {"a": 1, "b": {"c": 2}}),
    ]
    for original, update, expected in cases:
        assert update_with_nested_dict(original, update) == expected

57 

58 

def test_kwargs_to_nested_dict():
    """`kwargs_to_nested_dict` nesting, prefix stripping and unknown-prefix policies."""
    expected = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # plain dotted keys, no prefix handling
    assert kwargs_to_nested_dict({"a.b.c": 1, "a.b.d": 2, "a.e": 3}) == expected

    # every key carries the prefix, which is stripped before nesting
    stripped = kwargs_to_nested_dict(
        {"prefix.a.b.c": 1, "prefix.a.b.d": 2, "prefix.a.e": 3},
        strip_prefix="prefix.",
    )
    assert stripped == expected

    # no key carries the prefix and the policy is "raise"
    with pytest.raises(ValueError):
        kwargs_to_nested_dict(
            {"a.b.c": 1, "a.b.d": 2, "a.e": 3},
            strip_prefix="prefix.",
            when_unknown_prefix="raise",
        )

    # "a.e" lacks the "--" prefix; under "ignore" it is still kept
    ignored = kwargs_to_nested_dict(
        {"--a.b.c": 1, "--a.b.d": 2, "a.e": 3},
        strip_prefix="--",
        when_unknown_prefix="ignore",
    )
    assert ignored == expected

    # "a.e" lacks the "-" prefix; under "warn" a UserWarning must be emitted
    with pytest.warns(UserWarning):
        kwargs_to_nested_dict(
            {"--a.b.c": 1, "-a.b.d": 2, "a.e": 3},
            strip_prefix="-",
            when_unknown_prefix="warn",
        )

92 

93 

def test_kwargs_to_nested_dict_transform_key():
    """`transform_key` rewrites keys, alone or together with prefix handling.

    Once any prefix is removed, none of these keys contains a ".", so every
    result is a flat dict.
    """

    def dashes_to_underscores(x):
        # same behaviour as the inline lambda: "-" -> "_"
        return x.replace("-", "_")

    expected = {"a_b_c": 1, "a_b_d": 2, "a_e": 3}

    # transform_key on its own
    only_transform = kwargs_to_nested_dict(
        {"a-b-c": 1, "a-b-d": 2, "a-e": 3},
        transform_key=dashes_to_underscores,
    )
    assert only_transform == expected

    # prefix stripped and keys transformed
    prefixed = kwargs_to_nested_dict(
        {"prefix.a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
        strip_prefix="prefix.",
        transform_key=dashes_to_underscores,
    )
    assert prefixed == expected

    # "a-b-c" lacks the prefix: the "raise" policy rejects the input
    with pytest.raises(ValueError):
        kwargs_to_nested_dict(
            {"a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
            strip_prefix="prefix.",
            transform_key=dashes_to_underscores,
            when_unknown_prefix="raise",
        )

    # same input under "warn": a UserWarning, and the unprefixed key is kept
    with pytest.warns(UserWarning):
        warned = kwargs_to_nested_dict(
            {"a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
            strip_prefix="prefix.",
            transform_key=dashes_to_underscores,
            when_unknown_prefix="warn",
        )
    assert warned == expected

124 

125 

@serializable_dataclass
class ChildData(SerializableDataclass):
    """Leaf record nested inside `ParentData` by the update tests."""

    x: int  # overwritten by some updates, left untouched by others
    y: int

130 

131 

@serializable_dataclass
class ParentData(SerializableDataclass):
    """Outer record: one scalar field plus one nested `ChildData` field."""

    a: int  # top-level scalar
    b: ChildData  # nested record targeted by nested/dotted updates

136 

137 

def test_update_from_nested_dict():
    """`update_from_nested_dict` applies partial nested updates in place."""
    parent = ParentData(a=1, b=ChildData(x=2, y=3))

    def snapshot():
        # current values of every leaf field, in a fixed order
        return (parent.a, parent.b.x, parent.b.y)

    # a top-level field plus one nested field; b.y is left alone
    parent.update_from_nested_dict({"a": 5, "b": {"x": 6}})
    assert snapshot() == (5, 6, 3)

    # a nested-only update leaves the other fields as they were
    parent.update_from_nested_dict({"b": {"y": 7}})
    assert snapshot() == (5, 6, 7)

153 

154 

def test_update_from_dotlists():
    """Dotted-key updates, expanded via `dotlist_to_nested_dict`, apply in place."""
    parent = ParentData(a=1, b=ChildData(x=2, y=3))

    def snapshot():
        # current values of every leaf field, in a fixed order
        return (parent.a, parent.b.x, parent.b.y)

    # "b.x" addresses the nested field; b.y is left alone
    parent.update_from_nested_dict(dotlist_to_nested_dict({"a": 5, "b.x": 6}))
    assert snapshot() == (5, 6, 3)

    # a single dotted key touches only b.y
    parent.update_from_nested_dict(dotlist_to_nested_dict({"b.y": 7}))
    assert snapshot() == (5, 6, 7)

170 

171 

172# Tests for is_numeric_consecutive 

@pytest.mark.parametrize(
    ("test_input", "expected"),
    [
        (["1", "2", "3"], True),  # already in order
        (["1", "3", "2"], True),  # input order does not matter
        (["1", "4", "2"], False),  # gap in the sequence
        ([], False),  # empty input
        (["a", "2", "3"], False),  # a non-numeric entry
    ],
)
def test_is_numeric_consecutive(test_input, expected):
    """`is_numeric_consecutive` against a table of string key lists."""
    result = is_numeric_consecutive(test_input)
    assert result == expected

185 

186 

187# Tests for condense_nested_dicts 

def test_condense_nested_dicts_single_level():
    """Consecutive numeric keys that share a value collapse into a range key."""
    condensed = condense_nested_dicts({"1": "a", "2": "a", "3": "b"})
    assert condensed == {"[1-2]": "a", "3": "b"}

192 

193 

def test_condense_nested_dicts_nested():
    """Range condensing is also applied inside a nested dict."""
    condensed = condense_nested_dicts({"1": {"1": "a", "2": "a"}, "2": "b"})
    assert condensed == {"1": {"[1-2]": "a"}, "2": "b"}

198 

199 

def test_condense_nested_dicts_non_numeric():
    """Non-numeric keys are grouped only when `condense_matching_values` is set."""
    data = {"a": "a", "b": "a", "c": "b"}

    untouched = condense_nested_dicts(data, condense_matching_values=False)
    assert untouched == data

    grouped = condense_nested_dicts(data, condense_matching_values=True)
    assert grouped == {"[a, b]": "a", "c": "b"}

207 

208 

def test_condense_nested_dicts_mixed_keys():
    """With mixed key kinds, matching numeric keys are listed, not ranged."""
    condensed = condense_nested_dicts({"1": "a", "2": "a", "a": "b"})
    assert condensed == {"[1, 2]": "a", "a": "b"}

212 

213 

214# Mocking a Tensor-like object for use in tests 

class MockTensor:
    """Minimal tensor stand-in that exposes only a ``shape`` attribute."""

    def __init__(self, shape):
        # kept exactly as passed: no validation, no conversion, no copy
        self.shape = shape

218 

219 

220# Test cases for `tuple_dims_replace` 

@pytest.mark.parametrize(
    ("input_tuple", "dims_names_map", "expected"),
    [
        ((1, 2, 3), {1: "A", 2: "B"}, ("A", "B", 3)),  # partial mapping
        ((4, 5, 6), {}, (4, 5, 6)),  # empty mapping
        ((7, 8), None, (7, 8)),  # no mapping at all
        ((1, 2, 3), {3: "C"}, (1, 2, "C")),  # only the last dim mapped
    ],
)
def test_tuple_dims_replace(input_tuple, dims_names_map, expected):
    """Mapped dim sizes are replaced by their names; unmapped ones pass through."""
    replaced = tuple_dims_replace(input_tuple, dims_names_map)
    assert replaced == expected

232 

233 

@pytest.fixture
def tensor_data():
    """Three mock tensors: the first two share a shape, the third differs."""
    shapes = {
        "tensor1": (10, 256, 256),
        "tensor2": (10, 256, 256),
        "tensor3": (10, 512, 256),
    }
    return {name: MockTensor(shape) for name, shape in shapes.items()}

242 

243 

def test_condense_tensor_dict_basic(tensor_data):
    """Shapes render as strings once the leading (batch) dimension is dropped.

    Checked both without and with grouping of tensors whose shapes match.
    """
    ungrouped = {
        "tensor1": "(256, 256)",
        "tensor2": "(256, 256)",
        "tensor3": "(512, 256)",
    }
    grouped = {
        "[tensor1, tensor2]": "(256, 256)",
        "tensor3": "(512, 256)",
    }
    for condense, expected in ((False, ungrouped), (True, grouped)):
        result = condense_tensor_dict(
            tensor_data,
            drop_batch_dims=1,
            condense_matching_values=condense,
        )
        assert result == expected

263 

264 

def test_condense_tensor_dict_shapes_convert(tensor_data):
    """A custom `shapes_convert` controls how each trimmed shape is rendered.

    Passing the identity function returns raw shape tuples instead of the
    string form produced without it. Checked both without and with grouping of
    matching shapes.
    """

    def shapes_convert(x):
        # identity: hand the shape back unchanged (a named def instead of an
        # assigned lambda, so no E731 suppression is needed)
        return x

    assert condense_tensor_dict(
        tensor_data,
        shapes_convert=shapes_convert,
        drop_batch_dims=1,
        condense_matching_values=False,
    ) == {
        "tensor1": (256, 256),
        "tensor2": (256, 256),
        "tensor3": (512, 256),
    }

    assert condense_tensor_dict(
        tensor_data,
        shapes_convert=shapes_convert,
        drop_batch_dims=1,
        condense_matching_values=True,
    ) == {
        "[tensor1, tensor2]": (256, 256),
        "tensor3": (512, 256),
    }

288 

289 

def test_condense_tensor_dict_named_dims(tensor_data):
    """With `dims_names_map`, every dim size is shown by its mapped name.

    No batch dims are dropped here, so all three dims appear in each shape.
    """
    ungrouped = {
        "tensor1": "(B, A, A)",
        "tensor2": "(B, A, A)",
        "tensor3": "(B, C, A)",
    }
    grouped = {"[tensor1, tensor2]": "(B, A, A)", "tensor3": "(B, C, A)"}
    for condense, expected in ((False, ungrouped), (True, grouped)):
        result = condense_tensor_dict(
            tensor_data,
            # fresh mapping literal on each call, as in separate calls
            dims_names_map={10: "B", 256: "A", 512: "C"},
            condense_matching_values=condense,
        )
        assert result == expected

306 

307 

@pytest.mark.parametrize(
    "input_data,expected,fallback_mapping",
    [
        # no two keys share a value: nothing is condensed
        ({"a": 1, "b": 2}, {"a": 1, "b": 2}, None),
        # keys sharing a value are merged under a single "[k1, k2]" key
        ({"a": 1, "b": 1, "c": 2}, {"[a, b]": 1, "c": 2}, None),
        # merging also happens inside a nested dict
        ({"a": {"x": 1, "y": 1}, "b": 2}, {"a": {"[x, y]": 1}, "b": 2}, None),
        # no value repeats within any single dict, so nothing is merged; equal
        # values living in different sub-dicts are not combined
        (
            {"a": {"x": 1, "y": 2}, "b": {"x": 1, "z": 3}, "c": 1},
            {"a": {"x": 1, "y": 2}, "b": {"x": 1, "z": 3}, "c": 1},
            None,
        ),
        # unhashable (list) values with `str` as the fallback mapping: both
        # lists map to "[1, 2]", are merged, and the merged value is that string.
        # The same input without a fallback mapping is expected to fail, so that
        # variant is deliberately not parametrized here.
        (
            {"a": [1, 2], "b": [1, 2], "c": "test"},
            {"[a, b]": "[1, 2]", "c": "test"},
            str,
        ),
    ],
)
def test_condense_nested_dicts_matching_values(input_data, expected, fallback_mapping):
    """Check `condense_nested_dicts_matching_values` on flat, nested and unhashable input.

    The fallback mapping is passed only when a case provides one, so the
    function's own default is exercised for the other cases.
    """
    if fallback_mapping is not None:
        result = condense_nested_dicts_matching_values(input_data, fallback_mapping)
    else:
        result = condense_nested_dicts_matching_values(input_data)
    assert result == expected, f"Expected {expected}, got {result}"

339 

340 

341# Tests for `nested_dict_to_dotlist`

def test_nested_dict_to_dotlist_basic():
    """Nested keys flatten to dot-joined paths."""
    flattened = nested_dict_to_dotlist({"a": {"b": {"c": 1, "d": 2}, "e": 3}})
    assert flattened == {"a.b.c": 1, "a.b.d": 2, "a.e": 3}

346 

347 

def test_nested_dict_to_dotlist_empty():
    """An empty dict flattens to an empty dict."""
    empty_input: dict = {}
    assert nested_dict_to_dotlist(empty_input) == {}

353 

354 

def test_nested_dict_to_dotlist_single_level():
    """A dict with no nesting flattens to an equal dict."""
    flat_input: Dict[str, int] = {"a": 1, "b": 2, "c": 3}
    flattened = nested_dict_to_dotlist(flat_input)
    assert flattened == {"a": 1, "b": 2, "c": 3}

359 

360 

def test_nested_dict_to_dotlist_with_list():
    """With `allow_lists=True`, list items are addressed by their index."""
    source: dict = {"a": [1, 2, {"b": 3}], "c": 4}
    flattened = nested_dict_to_dotlist(source, allow_lists=True)
    wanted: Dict[str, int] = {"a.0": 1, "a.1": 2, "a.2.b": 3, "c": 4}
    assert flattened == wanted

365 

366 

def test_nested_dict_to_dotlist_nested_empty():
    """An empty nested dict is kept as a leaf value rather than dropped."""
    source: dict = {"a": {"b": {}}}
    assert nested_dict_to_dotlist(source) == {"a.b": {}}

371 

372 

def test_round_trip_conversion():
    """Flattening and then re-nesting reproduces the original dict."""
    original: dict = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}
    rebuilt = dotlist_to_nested_dict(nested_dict_to_dotlist(original))
    assert rebuilt == original