Coverage for muutils / dictmagic.py: 86%
160 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
1"""making working with dictionaries easier
- `DefaulterDict`: like a defaultdict, but default_factory is passed the key as an argument
- various methods for working with dotlist-nested dicts, converting to and from them
- `condense_nested_dicts`: condense a nested dict by merging numeric keys into ranges and grouping keys that have matching values
- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes
- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict
8"""
10from __future__ import annotations
12import typing
13import warnings
14from collections import defaultdict
15from typing import (
16 Any,
17 Callable,
18 Generic,
19 Hashable,
20 Iterable,
21 Literal,
22 Optional,
23 TypeVar,
24 Union,
25)
27from muutils.errormode import ErrorMode
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")


class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
    """A dict that fills in missing keys by calling `default_factory(key)`.

    Unlike `collections.defaultdict`, the factory receives the missing key as its
    single argument. The produced value is stored, so the factory runs at most
    once per key. Only item access (`d[key]`) is overridden here.

    Initial contents may be given only as keyword arguments; any positional
    argument besides `default_factory` raises `TypeError`.
    """

    def __init__(
        self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any
    ) -> None:
        # positional mappings/iterables are deliberately rejected
        if len(args) > 0:
            raise TypeError(
                f"DefaulterDict does not support positional arguments: *args = {args}"
            )
        super().__init__(**kwargs)
        self.default_factory: Callable[[_KT], _VT] = default_factory

    def __getitem__(self, k: _KT) -> _VT:
        # on a miss, compute and cache the value before reading it back
        if k not in self:
            dict.__setitem__(self, k, self.default_factory(k))
        return dict.__getitem__(self, k)
def _recursive_defaultdict_ctor() -> defaultdict:
    """Return an empty `defaultdict` whose default factory is this same function.

    Reading a missing key therefore creates another such `defaultdict`, so deep
    paths like `d["a"]["b"]["c"] = 1` can be assigned without creating the
    intermediate levels by hand.
    """
    factory = _recursive_defaultdict_ctor
    return defaultdict(factory)
def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
    """Convert a defaultdict or DefaulterDict to a normal dict, recursively.

    Only values that are themselves `defaultdict` or `DefaulterDict` instances are
    descended into and converted. All other values, including plain dicts, are
    copied over by reference.
    """
    converted: dict = {}
    for key, value in dd.items():
        is_nested = isinstance(value, (defaultdict, DefaulterDict))
        converted[key] = defaultdict_to_dict_recursive(value) if is_nested else value
    return converted
def dotlist_to_nested_dict(
    dot_dict: typing.Dict[str, Any], sep: str = "."
) -> typing.Dict[str, Any]:
    """Convert a dict with `sep`-separated keys to a nested dict

    Example:

    >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}

    # Raises
    - `TypeError` if any key is not a `str`
    """
    # auto-vivifying tree: missing intermediate levels are created on access
    root: defaultdict = _recursive_defaultdict_ctor()
    for dotted_key, value in dot_dict.items():
        if not isinstance(dotted_key, str):
            raise TypeError(f"key must be a string, got {type(dotted_key)}")
        # `str.split` always yields at least one part, so this unpacking is safe
        *parents, leaf = dotted_key.split(sep)
        node: defaultdict = root
        for part in parents:
            node = node[part]
        node[leaf] = value
    return defaultdict_to_dict_recursive(root)
def nested_dict_to_dotlist(
    nested_dict: typing.Dict[str, Any],
    sep: str = ".",
    allow_lists: bool = False,
) -> dict[str, Any]:
    """Flatten a nested dict into a flat dict whose keys are `sep`-joined paths

    Roughly the inverse of `dotlist_to_nested_dict`.

    Example:

    >>> nested_dict_to_dotlist({'a': {'b': {'c': 1, 'd': 2}, 'e': 3}})
    {'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}
    >>> nested_dict_to_dotlist({'a': [1, {'b': 2}], 'c': {}}, allow_lists=True)
    {'a.0': 1, 'a.1.b': 2, 'c': {}}

    # Arguments
    - `nested_dict: dict[str, Any]`
        the dict to flatten
    - `sep: str = "."`
        separator placed between path components
    - `allow_lists: bool = False`
        if `True`, lists are also descended into, using the element index as the
        path component. if `False`, lists are kept as leaf values

    # Returns
    - `dict[str, Any]`
        flat dict mapping paths to leaf values. Empty dicts (and, with
        `allow_lists`, empty lists) below the top level are kept as leaf values
        rather than being dropped. Top-level dict keys are used as-is; deeper
        paths are built with f-string formatting.
    """
    # Sentinel meaning "no parent yet". It is compared by identity so that falsy
    # top-level keys such as `0` or `""` are still used as path prefixes.
    root: Any = object()

    def _recurse(current: Any, parent_key: Any = root) -> typing.Dict[str, Any]:
        items: dict = dict()
        at_root: bool = parent_key is root
        is_dict: bool = isinstance(current, dict)
        is_list: bool = allow_lists and isinstance(current, list)

        if not (is_dict or is_list):
            # leaf value; a bare non-container at the root is stored under ""
            items["" if at_root else parent_key] = current
            return items

        if not current and not at_root:
            # keep empty containers as leaves instead of silently dropping them
            items[parent_key] = current
            return items

        if is_dict:
            for k, v in current.items():
                new_key = k if at_root else f"{parent_key}{sep}{k}"
                items.update(_recurse(v, new_key))
        else:
            for i, item in enumerate(current):
                new_key = str(i) if at_root else f"{parent_key}{sep}{i}"
                items.update(_recurse(item, new_key))
        return items

    return _recurse(nested_dict)
def update_with_nested_dict(
    original: dict[str, Any],
    update: dict[str, Any],
) -> dict[str, Any]:
    """Recursively merge `update` into `original`, modifying `original` in place

    Example:
    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
    {'a': {'b': 2}, 'c': -1}

    # Arguments
    - `original: dict[str, Any]`
        the dict to update (will be modified in-place)
    - `update: dict[str, Any]`
        the dict to update with. Where both sides hold a dict for the same key,
        the two are merged recursively. Otherwise the value from `update`
        replaces (or is added to) `original`, inserted by reference.

    # Returns
    - `dict`
        the same `original` object, after updating
    """
    for key, value in update.items():
        existing_is_dict = key in original and isinstance(original[key], dict)
        if existing_is_dict and isinstance(value, dict):
            update_with_nested_dict(original[key], value)
        else:
            original[key] = value
    return original
def kwargs_to_nested_dict(
    kwargs_dict: dict[str, Any],
    sep: str = ".",
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN,
    transform_key: Optional[Callable[[str], str]] = None,
) -> dict[str, Any]:
    """given kwargs from fire, convert them to a nested dict

    If `strip_prefix` is not None, all keys are expected to start with it, and the
    prefix is removed. When a key lacks the prefix, the situation is reported through
    `when_unknown_prefix` (warn by default; can be set to raise an error or ignore it).
    If reporting does not raise, the key is kept unstripped.

    Example:
    ```python
    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)
    ```
    running the above script will give:
    ```bash
    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    ```

    # Arguments
    - `kwargs_dict: dict[str, Any]`
        the kwargs dict to convert
    - `sep: str = "."`
        the separator to use for nested keys
    - `strip_prefix: Optional[str] = None`
        if not None, then all keys must start with this prefix
    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
        what to do when an unknown prefix is found
    - `transform_key: Callable[[str], str] | None = None`
        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
    """
    mode: ErrorMode = ErrorMode.from_any(when_unknown_prefix)
    renamed: dict[str, Any] = {}
    for key, value in kwargs_dict.items():
        if strip_prefix is not None:
            if key.startswith(strip_prefix):
                key = key[len(strip_prefix) :]
            else:
                mode.process(
                    f"key '{key}' does not start with '{strip_prefix}'",
                    except_cls=ValueError,
                )
        if transform_key is not None:
            key = transform_key(key)
        renamed[key] = value
    return dotlist_to_nested_dict(renamed, sep=sep)
def is_numeric_consecutive(lst: list[str]) -> bool:
    """Check if the list of keys is numeric and consecutive.

    Every element must be accepted by `int()`, and the resulting integers, once
    sorted, must form an unbroken run with no gaps or duplicates. For example,
    `["3", "1", "2"]` gives True and `["1", "3"]` gives False.

    Returns False for an empty list, or if any element cannot be converted. That
    includes non-string keys such as `None` or tuples, for which `int()` raises
    `TypeError` rather than `ValueError`.
    """
    if not lst:
        return False
    try:
        numbers: list[int] = [int(x) for x in lst]
    except (ValueError, TypeError):
        return False
    return sorted(numbers) == list(range(min(numbers), max(numbers) + 1))


def condense_nested_dicts_numeric_keys(
    data: dict[str, Any],
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric keys with matching values to ranges

    Sub-dicts are processed first. Then, if all keys of a dict are numeric and
    consecutive (see `is_numeric_consecutive`), runs of adjacent keys with equal
    values are merged into a single `"[start-end]"` key. The input is not
    modified; a new dict is returned. Non-dict inputs are returned unchanged.

    # Examples:
    ```python
    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {'1': {'[1-2]': 'a'}, '2': 'b'}
    ```
    """
    if not isinstance(data, dict):
        return data

    # recurse into the values, collecting into a *new* dict so that the caller's
    # (possibly nested) input is never mutated
    processed: dict[str, Any] = {
        key: condense_nested_dicts_numeric_keys(value) for key, value in data.items()
    }

    if not is_numeric_consecutive(list(processed.keys())):
        return processed

    keys: list[str] = sorted(processed.keys(), key=int)

    # walk the sorted keys, extending each run while neighboring values are equal
    condensed_data: dict[str, Any] = {}
    i: int = 0
    while i < len(keys):
        j: int = i
        while j + 1 < len(keys) and processed[keys[j]] == processed[keys[j + 1]]:
            j += 1
        if j > i:
            condensed_data[f"[{keys[i]}-{keys[j]}]"] = processed[keys[i]]
        else:
            condensed_data[keys[i]] = processed[keys[i]]
        i = j + 1

    return condensed_data
def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing keys with matching values

    Sub-dicts are processed recursively. Within each dict, keys whose (non-dict)
    values are equal are merged into a single key of the form `"[k1, k2, ...]"`.
    Dict values and unhashable values without a fallback mapping are kept as-is
    and placed before the grouped entries. Non-dict inputs are returned unchanged.

    # Examples:
    ```python
    >>> condense_nested_dicts_matching_values({'a': 1, 'b': 1, 'c': 2})
    {'[a, b]': 1, 'c': 2}
    >>> condense_nested_dicts_matching_values({'p': [1], 'q': [1]}, tuple)
    {'[p, q]': (1,)}
    ```

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        Applied to a value that is not hashable. The hashable result is used for
        grouping and is also what gets stored in the output in place of the
        original value.
        (defaults to `None`)
    """

    if isinstance(data, dict):
        data = {
            key: condense_nested_dicts_matching_values(
                value, val_condense_fallback_mapping
            )
            for key, value in data.items()
        }
    else:
        return data

    # Find all identical values and condense by stitching together keys
    values_grouped: defaultdict[Any, list[Any]] = defaultdict(list)
    data_persist: dict[str, Any] = dict()
    for key, value in data.items():
        if not isinstance(value, dict):
            try:
                values_grouped[value].append(key)
            except TypeError:
                # unhashable value: use the fallback mapping to get a hashable representation
                if val_condense_fallback_mapping is not None:
                    values_grouped[val_condense_fallback_mapping(value)].append(key)
                else:
                    data_persist[key] = value
        else:
            data_persist[key] = value

    condensed_data = data_persist
    for value, keys in values_grouped.items():
        if len(keys) > 1:
            # keys may be non-str (e.g. ints that were not consecutive), so
            # stringify each one before joining
            merged_key = f"[{', '.join(str(k) for k in keys)}]"
            condensed_data[merged_key] = value
        else:
            condensed_data[keys[0]] = value

    return condensed_data
def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict by applying the enabled condensing passes in sequence

    Runs `condense_nested_dicts_numeric_keys()` first (if enabled), then
    `condense_nested_dicts_matching_values()` (if enabled), feeding each pass the
    previous pass's output.

    # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
    it's not reversible because types are lost to make the printing pretty

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        passed to `condense_nested_dicts_matching_values()`; applied to values that are not hashable
        (defaults to `None`)
    """
    passes: list[Callable[[Any], Any]] = []
    if condense_numeric_keys:
        passes.append(condense_nested_dicts_numeric_keys)
    if condense_matching_values:
        passes.append(
            lambda d: condense_nested_dicts_matching_values(
                d, val_condense_fallback_mapping
            )
        )

    result: dict = data
    for condense_pass in passes:
        result = condense_pass(result)
    return result
def tuple_dims_replace(
    t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None
) -> tuple[Union[int, str], ...]:
    """Replace each element of `t` that is a key in `dims_names_map` with its mapped name.

    Elements not present in the map are kept as-is. When `dims_names_map` is None,
    `t` itself is returned unchanged.
    """
    if dims_names_map is None:
        return t
    lookup = dims_names_map.get
    return tuple(lookup(dim, dim) for dim in t)
# type aliases used by `condense_tensor_dict`; the tensor types are written as
# strings, so this module never evaluates `torch` or `np`
TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"]  # type: ignore[name-defined] # noqa: F821
TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]]  # type: ignore[name-defined] # noqa: F821
TensorDictFormats = Literal["dict", "json", "yaml", "yml"]


def _default_shapes_convert(x: tuple) -> str:
    """Render a shape tuple via `str()` with every single and double quote removed.

    e.g. `(12, 'd_model')` -> `"(12, d_model)"`
    """
    # third argument of `str.maketrans` lists characters to delete
    return str(x).translate(str.maketrans("", "", "\"'"))
def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args: Any,
    shapes_convert: Optional[
        Callable[[tuple[Union[int, str], ...]], Any]
    ] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    By default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
    - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` mapping strings to tensors, or a `TensorIterable` of (key, tensor) pairs (like you might get from `dict().items()`).
        Only the `.shape` attribute of each value is read.
    - `fmt : TensorDictFormats`
        format to return the result in: either a dict, or dumped to json/yaml directly for pretty printing. yaml output requires PyYAML.
        (defaults to `'dict'`)
    - `shapes_convert : Callable[[tuple], Any] | None`
        conversion of a shape tuple to a string or other format. `None` means identity (keep the tuple).
        (defaults to `_default_shapes_convert`, i.e. `str(x)` with quote characters removed)
    - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
    - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
    - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
    - `return_format : TensorDictFormats | None`
        deprecated legacy alias for the `fmt` kwarg; overrides `fmt` when given

    # Returns:
    - `str|dict[str, str|tuple[int, ...]]`
        dict if `fmt='dict'`, a string for `json` or `yaml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
    - `TypeError` : if any extra positional arguments are passed (everything after `fmt` is keyword-only)
    - `ValueError` : if the format is not one of 'dict', 'json', 'yaml', or 'yml', or if you try to use yaml output without having PyYAML installed
    """

    # handle arg processing:
    # ----------------------------------------------------------------------
    # All args except `data` and `fmt` are keyword-only; `*args` exists only to
    # catch accidental positional use. Raise rather than `assert` so the check
    # also holds under `python -O`.
    if args:
        raise TypeError(f"unexpected positional args: {args}")
    # handle legacy return_format
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
            # attribute the warning to the caller, not to this function
            stacklevel=2,
        )
        fmt = return_format

    # identity function for shapes_convert if not provided
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # accept either a mapping (anything with a callable `.items`) or an iterable of pairs
    data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # get shapes
    data_shapes: dict[str, Union[str, tuple[int, ...]]] = {  # pyright: ignore[reportAssignmentType]
        k: shapes_convert(
            tuple_dims_replace(
                tuple(v.shape)[drop_batch_dims:],
                dims_names_map,
            )
        )
        for k, v in data_items
    }

    # nest the dict
    data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep)

    # condense the nested dict
    data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=data_nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # return in the specified format
    fmt_lower: str = fmt.lower()
    if fmt_lower == "dict":
        return data_condensed
    elif fmt_lower == "json":
        import json

        return json.dumps(data_condensed, indent=2)
    elif fmt_lower in ["yaml", "yml"]:
        # only the import is guarded, so errors raised while dumping are not
        # misreported as a missing dependency
        try:
            import yaml  # type: ignore[import-untyped]
        except ImportError as e:
            raise ValueError("PyYAML is required for YAML output") from e
        return yaml.dump(data_condensed, sort_keys=False)
    else:
        raise ValueError(f"Invalid return format: {fmt}")