Coverage for tests/unit/test_dictmagic.py: 100%
131 statements
coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1from __future__ import annotations
2from typing import Dict
4import pytest
6from muutils.dictmagic import (
7 condense_nested_dicts,
8 condense_nested_dicts_matching_values,
9 condense_tensor_dict,
10 dotlist_to_nested_dict,
11 is_numeric_consecutive,
12 kwargs_to_nested_dict,
13 nested_dict_to_dotlist,
14 tuple_dims_replace,
15 update_with_nested_dict,
16)
17from muutils.json_serialize import SerializableDataclass, serializable_dataclass
def test_dotlist_to_nested_dict():
    """`dotlist_to_nested_dict` expands separator-joined keys into nested dicts."""
    want = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # default "." separator
    got_dot = dotlist_to_nested_dict({"a.b.c": 1, "a.b.d": 2, "a.e": 3})
    assert got_dot == want

    # an integer key is rejected with TypeError
    with pytest.raises(TypeError):
        dotlist_to_nested_dict({1: 1})  # type: ignore[dict-item]

    # custom separator via `sep`
    got_slash = dotlist_to_nested_dict({"a/b/c": 1, "a/b/d": 2, "a/e": 3}, sep="/")
    assert got_slash == want
def test_update_with_nested_dict():
    """`update_with_nested_dict` merges an update into a (possibly nested) dict."""
    # (original, update, expected) triples, checked in order
    cases = [
        # an existing nested key is overwritten
        ({"a": {"b": 1}, "c": -1}, {"a": {"b": 2}}, {"a": {"b": 2}, "c": -1}),
        # a key absent from the original is added
        (
            {"a": {"b": 1}, "c": -1},
            {"d": 3},
            {"a": {"b": 1}, "c": -1, "d": 3},
        ),
        # sibling keys inside the updated sub-dict are preserved
        (
            {"a": {"b": 1, "d": 3}, "c": -1},
            {"a": {"b": 2}},
            {"a": {"b": 2, "d": 3}, "c": -1},
        ),
        # a nested update for a sub-dict that does not exist yet
        ({"a": 1}, {"b": {"c": 2}}, {"a": 1, "b": {"c": 2}}),
    ]
    for original, update, expected in cases:
        assert update_with_nested_dict(original, update) == expected
def test_kwargs_to_nested_dict():
    """Check `kwargs_to_nested_dict` nesting, prefix stripping and unknown-prefix modes."""
    # Plain dotted keys are nested on "."
    assert kwargs_to_nested_dict({"a.b.c": 1, "a.b.d": 2, "a.e": 3}) == {
        "a": {"b": {"c": 1, "d": 2}, "e": 3}
    }

    # strip_prefix is removed from every key before nesting
    assert kwargs_to_nested_dict(
        {"prefix.a.b.c": 1, "prefix.a.b.d": 2, "prefix.a.e": 3}, strip_prefix="prefix."
    ) == {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # when_unknown_prefix="raise": none of these keys carry "prefix.",
    # so a ValueError is expected
    with pytest.raises(ValueError):
        kwargs_to_nested_dict(
            {"a.b.c": 1, "a.b.d": 2, "a.e": 3},
            strip_prefix="prefix.",
            when_unknown_prefix="raise",
        )

    # when_unknown_prefix="ignore": "--"-prefixed keys are stripped, while the
    # single unprefixed key "a.e" is kept as-is (no "-" prefix appears here)
    assert kwargs_to_nested_dict(
        {"--a.b.c": 1, "--a.b.d": 2, "a.e": 3},
        strip_prefix="--",
        when_unknown_prefix="ignore",
    ) == {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # when_unknown_prefix="warn": keys mix "--" and "-" prefixes with
    # strip_prefix="-"; "a.e" is the only key lacking "-", so a UserWarning
    # is expected
    with pytest.warns(UserWarning):
        kwargs_to_nested_dict(
            {"--a.b.c": 1, "-a.b.d": 2, "a.e": 3},
            strip_prefix="-",
            when_unknown_prefix="warn",
        )
def test_kwargs_to_nested_dict_transform_key():
    """`transform_key` rewrites every key; here dashes become underscores."""

    def to_snake(x):
        return x.replace("-", "_")

    expected = {"a_b_c": 1, "a_b_d": 2, "a_e": 3}

    # transform_key on its own: the keys contain no ".", so the result stays flat
    assert (
        kwargs_to_nested_dict(
            {"a-b-c": 1, "a-b-d": 2, "a-e": 3}, transform_key=to_snake
        )
        == expected
    )

    # strip_prefix combined with transform_key
    assert (
        kwargs_to_nested_dict(
            {"prefix.a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
            strip_prefix="prefix.",
            transform_key=to_snake,
        )
        == expected
    )

    # "a-b-c" lacks the prefix, so when_unknown_prefix="raise" -> ValueError
    with pytest.raises(ValueError):
        kwargs_to_nested_dict(
            {"a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
            strip_prefix="prefix.",
            transform_key=to_snake,
            when_unknown_prefix="raise",
        )

    # same input with "warn": a UserWarning is emitted and the key is still kept
    with pytest.warns(UserWarning):
        assert (
            kwargs_to_nested_dict(
                {"a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
                strip_prefix="prefix.",
                transform_key=to_snake,
                when_unknown_prefix="warn",
            )
            == expected
        )
@serializable_dataclass
class ChildData(SerializableDataclass):
    """Inner dataclass, used as the type of `ParentData.b` in the update tests."""

    x: int  # first nested field
    y: int  # second nested field
@serializable_dataclass
class ParentData(SerializableDataclass):
    """Outer dataclass holding one plain field and one nested `ChildData`."""

    a: int  # top-level scalar field
    b: ChildData  # nested dataclass field
def test_update_from_nested_dict():
    """`update_from_nested_dict` overwrites only the fields named in the update."""
    parent = ParentData(a=1, b=ChildData(x=2, y=3))

    # top-level field plus one nested field; b.y must survive untouched
    parent.update_from_nested_dict({"a": 5, "b": {"x": 6}})
    assert (parent.a, parent.b.x, parent.b.y) == (5, 6, 3)

    # a nested-only update leaves a and b.x as they were
    parent.update_from_nested_dict({"b": {"y": 7}})
    assert (parent.a, parent.b.x, parent.b.y) == (5, 6, 7)
def test_update_from_dotlists():
    """Dotted updates, expanded via `dotlist_to_nested_dict`, apply like nested ones."""
    parent = ParentData(a=1, b=ChildData(x=2, y=3))

    # each step: (dotlist update, expected (a, b.x, b.y) afterwards)
    steps = [
        ({"a": 5, "b.x": 6}, (5, 6, 3)),
        ({"b.y": 7}, (5, 6, 7)),
    ]
    for dotted, (want_a, want_x, want_y) in steps:
        parent.update_from_nested_dict(dotlist_to_nested_dict(dotted))
        assert parent.a == want_a
        assert parent.b.x == want_x
        assert parent.b.y == want_y
# Tests for is_numeric_consecutive
@pytest.mark.parametrize(
    "test_input,expected",
    [
        (["1", "2", "3"], True),  # ascending run
        (["1", "3", "2"], True),  # same run, unordered
        (["1", "4", "2"], False),  # gap: "3" missing
        ([], False),  # empty input
        (["a", "2", "3"], False),  # non-numeric entry
    ],
)
def test_is_numeric_consecutive(test_input, expected):
    """Check `is_numeric_consecutive` on lists of numeric-looking strings."""
    outcome = is_numeric_consecutive(test_input)
    assert outcome == expected
# Tests for condense_nested_dicts
def test_condense_nested_dicts_single_level():
    """Consecutive numeric keys sharing a value collapse into a "[lo-hi]" key."""
    result = condense_nested_dicts({"1": "a", "2": "a", "3": "b"})
    assert result == {"[1-2]": "a", "3": "b"}
def test_condense_nested_dicts_nested():
    """Condensing also happens inside nested dicts."""
    assert condense_nested_dicts({"1": {"1": "a", "2": "a"}, "2": "b"}) == {
        "1": {"[1-2]": "a"},
        "2": "b",
    }
def test_condense_nested_dicts_non_numeric():
    """Non-numeric keys are merged only when `condense_matching_values` is on."""
    data = {"a": "a", "b": "a", "c": "b"}

    # disabled: the input comes back unchanged
    unchanged = condense_nested_dicts(data, condense_matching_values=False)
    assert unchanged == data

    # enabled: keys sharing the value "a" are grouped into a bracketed list
    grouped = condense_nested_dicts(data, condense_matching_values=True)
    assert grouped == {"[a, b]": "a", "c": "b"}
def test_condense_nested_dicts_mixed_keys():
    """With a non-numeric key present, merged keys are listed, not ranged.

    The merged numeric keys come out as "[1, 2]" rather than the "[1-2]" range
    form seen in the all-numeric single-level test.
    """
    mixed = {"1": "a", "2": "a", "a": "b"}
    expected = {"[1, 2]": "a", "a": "b"}
    assert condense_nested_dicts(mixed) == expected
# Minimal tensor-like stand-in used by the condense_tensor_dict tests
class MockTensor:
    """Stub object that exposes nothing but a ``shape`` attribute."""

    def __init__(self, shape):
        # stored verbatim; the tests pass plain tuples
        self.shape = shape
# Test cases for `tuple_dims_replace`
@pytest.mark.parametrize(
    "input_tuple,dims_names_map,expected",
    [
        # mapped entries replaced, unmapped ones left alone
        ((1, 2, 3), {1: "A", 2: "B"}, ("A", "B", 3)),
        # an empty map changes nothing
        ((4, 5, 6), {}, (4, 5, 6)),
        # None as the map changes nothing
        ((7, 8), None, (7, 8)),
        # only the last element is mapped
        ((1, 2, 3), {3: "C"}, (1, 2, "C")),
    ],
)
def test_tuple_dims_replace(input_tuple, dims_names_map, expected):
    """Each tuple element found in `dims_names_map` is swapped for its name."""
    replaced = tuple_dims_replace(input_tuple, dims_names_map)
    assert replaced == expected
@pytest.fixture
def tensor_data():
    """Three mock tensors: two of shape (10, 256, 256), one of (10, 512, 256)."""
    shapes = {
        "tensor1": (10, 256, 256),
        "tensor2": (10, 256, 256),
        "tensor3": (10, 512, 256),
    }
    return {name: MockTensor(shape) for name, shape in shapes.items()}
def test_condense_tensor_dict_basic(tensor_data):
    """Shapes become strings without the leading dim; equal shapes can be merged."""
    expectations = (
        (
            False,
            {
                "tensor1": "(256, 256)",
                "tensor2": "(256, 256)",
                "tensor3": "(512, 256)",
            },
        ),
        (
            True,
            {
                "[tensor1, tensor2]": "(256, 256)",
                "tensor3": "(512, 256)",
            },
        ),
    )
    for condense, expected in expectations:
        result = condense_tensor_dict(
            tensor_data,
            drop_batch_dims=1,
            condense_matching_values=condense,
        )
        assert result == expected
def test_condense_tensor_dict_shapes_convert(tensor_data):
    """A custom `shapes_convert` replaces the default string rendering of shapes."""

    def identity(x):
        # hand the (batch-trimmed) shape back untouched
        return x

    per_tensor = condense_tensor_dict(
        tensor_data,
        shapes_convert=identity,
        drop_batch_dims=1,
        condense_matching_values=False,
    )
    assert per_tensor == {
        "tensor1": (256, 256),
        "tensor2": (256, 256),
        "tensor3": (512, 256),
    }

    grouped = condense_tensor_dict(
        tensor_data,
        shapes_convert=identity,
        drop_batch_dims=1,
        condense_matching_values=True,
    )
    assert grouped == {
        "[tensor1, tensor2]": (256, 256),
        "tensor3": (512, 256),
    }
def test_condense_tensor_dict_named_dims(tensor_data):
    """`dims_names_map` replaces dimension sizes with names.

    `drop_batch_dims` is not passed here, and the leading size 10 shows up as "B".
    """
    per_tensor = condense_tensor_dict(
        tensor_data,
        dims_names_map={10: "B", 256: "A", 512: "C"},
        condense_matching_values=False,
    )
    assert per_tensor == {
        "tensor1": "(B, A, A)",
        "tensor2": "(B, A, A)",
        "tensor3": "(B, C, A)",
    }

    grouped = condense_tensor_dict(
        tensor_data,
        dims_names_map={10: "B", 256: "A", 512: "C"},
        condense_matching_values=True,
    )
    assert grouped == {
        "[tensor1, tensor2]": "(B, A, A)",
        "tensor3": "(B, C, A)",
    }
@pytest.mark.parametrize(
    "input_data,expected,fallback_mapping",
    [
        # flat dict, all values distinct: nothing to merge
        ({"a": 1, "b": 2}, {"a": 1, "b": 2}, None),
        # flat dict where "a" and "b" share a value: merged into one key
        ({"a": 1, "b": 1, "c": 2}, {"[a, b]": 1, "c": 2}, None),
        # matching values inside a nested dict are merged at that level
        ({"a": {"x": 1, "y": 1}, "b": 2}, {"a": {"[x, y]": 1}, "b": 2}, None),
        # equal leaf values living in different sub-dicts are left alone
        (
            {"a": {"x": 1, "y": 2}, "b": {"x": 1, "z": 3}, "c": 1},
            {"a": {"x": 1, "y": 2}, "b": {"x": 1, "z": 3}, "c": 1},
            None,
        ),
        # unhashable (list) values with `str` as fallback_mapping: they are
        # grouped, and the merged value is the string form of the list.
        # (The same input without a fallback is not expected to work, so it
        # is deliberately not parametrized here.)
        (
            {"a": [1, 2], "b": [1, 2], "c": "test"},
            {"[a, b]": "[1, 2]", "c": "test"},
            str,
        ),
    ],
)
def test_condense_nested_dicts_matching_values(input_data, expected, fallback_mapping):
    """Pass `fallback_mapping` only when a case supplies one, then compare results."""
    call_args = (
        (input_data,) if fallback_mapping is None else (input_data, fallback_mapping)
    )
    result = condense_nested_dicts_matching_values(*call_args)
    assert result == expected, f"Expected {expected}, got {result}"
# Tests for `nested_dict_to_dotlist` ("ndtd")
def test_nested_dict_to_dotlist_basic():
    """Nested keys are flattened into "."-joined paths."""
    flattened = nested_dict_to_dotlist({"a": {"b": {"c": 1, "d": 2}, "e": 3}})
    assert flattened == {"a.b.c": 1, "a.b.d": 2, "a.e": 3}
def test_nested_dict_to_dotlist_empty():
    """An empty dict flattens to an empty dict."""
    empty: dict = {}
    assert nested_dict_to_dotlist(empty) == {}
def test_nested_dict_to_dotlist_single_level():
    """A dict with no nesting flattens to an equal dict."""
    flat: Dict[str, int] = {"a": 1, "b": 2, "c": 3}
    assert nested_dict_to_dotlist(flat) == dict(a=1, b=2, c=3)
def test_nested_dict_to_dotlist_with_list():
    """With `allow_lists=True`, list items are flattened under their index."""
    data: dict = {"a": [1, 2, {"b": 3}], "c": 4}
    flattened = nested_dict_to_dotlist(data, allow_lists=True)
    assert flattened == {"a.0": 1, "a.1": 2, "a.2.b": 3, "c": 4}
def test_nested_dict_to_dotlist_nested_empty():
    """An empty inner dict is kept as a leaf value under its dotted path."""
    assert nested_dict_to_dotlist({"a": {"b": {}}}) == {"a.b": {}}
def test_round_trip_conversion():
    """Nested <-> dotlist conversion round-trips in both directions for this input."""
    original: dict = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # nested -> dotlist -> nested reproduces the original nested dict
    dotlist = nested_dict_to_dotlist(original)
    result = dotlist_to_nested_dict(dotlist)
    assert result == original

    # dotlist -> nested -> dotlist reproduces the same dotlist
    assert nested_dict_to_dotlist(result) == dotlist