Coverage for muutils/dictmagic.py: 86%
160 statements
coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""making working with dictionaries easier
3- `DefaulterDict`: like a defaultdict, but default_factory is passed the key as an argument
4- various methods for working wit dotlist-nested dicts, converting to and from them
5- `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
6- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes
7- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict
8"""

from __future__ import annotations

import typing
import warnings
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Generic,
    Hashable,
    Iterable,
    Literal,
    Optional,
    TypeVar,
    Union,
)

from muutils.errormode import ErrorMode

_KT = TypeVar("_KT")
_VT = TypeVar("_VT")


class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
    """like a defaultdict, but default_factory is passed the key as an argument"""

    def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs):
        if args:
            raise TypeError(
                f"DefaulterDict does not support positional arguments: *args = {args}"
            )
        super().__init__(**kwargs)
        self.default_factory: Callable[[_KT], _VT] = default_factory

    def __getitem__(self, k: _KT) -> _VT:
        if k in self:
            return dict.__getitem__(self, k)
        else:
            v: _VT = self.default_factory(k)
            dict.__setitem__(self, k, v)
            return v
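
# a minimal usage sketch of `DefaulterDict`: unlike `collections.defaultdict`,
# the factory is called with the missing key, so the default can depend on it
# >>> d = DefaulterDict(lambda k: k * 2)
# >>> d["a"]
# 'aa'
# >>> d[3]
# 6
# >>> dict(d)
# {'a': 'aa', 3: 6}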


def _recursive_defaultdict_ctor() -> defaultdict:
    return defaultdict(_recursive_defaultdict_ctor)


def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
    """Convert a defaultdict or DefaulterDict to a normal dict, recursively"""
    return {
        key: (
            defaultdict_to_dict_recursive(value)
            if isinstance(value, (defaultdict, DefaulterDict))
            else value
        )
        for key, value in dd.items()
    }
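
# a quick sketch of how the recursive constructor pairs with `defaultdict_to_dict_recursive`:
# >>> dd = _recursive_defaultdict_ctor()
# >>> dd["a"]["b"] = 1
# >>> defaultdict_to_dict_recursive(dd)
# {'a': {'b': 1}}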


def dotlist_to_nested_dict(
    dot_dict: typing.Dict[str, Any], sep: str = "."
) -> typing.Dict[str, Any]:
    """Convert a dict with dot-separated keys to a nested dict

    Example:

    >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    """
    nested_dict: defaultdict = _recursive_defaultdict_ctor()
    for key, value in dot_dict.items():
        if not isinstance(key, str):
            raise TypeError(f"key must be a string, got {type(key)}")
        keys: list[str] = key.split(sep)
        current: defaultdict = nested_dict
        # iterate over the keys except the last one
        for sub_key in keys[:-1]:
            current = current[sub_key]
        current[keys[-1]] = value
    return defaultdict_to_dict_recursive(nested_dict)


def nested_dict_to_dotlist(
    nested_dict: typing.Dict[str, Any],
    sep: str = ".",
    allow_lists: bool = False,
) -> dict[str, Any]:
    """Convert a nested dict to a flat dict with `sep`-separated keys (inverse of `dotlist_to_nested_dict`)

    if `allow_lists` is `True`, list elements are keyed by their index
    """

    def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]:
        items: dict = dict()

        new_key: str
        if isinstance(current, dict):
            # dict case
            if not current and parent_key:
                items[parent_key] = current
            else:
                for k, v in current.items():
                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
                    items.update(_recurse(v, new_key))

        elif allow_lists and isinstance(current, list):
            # list case
            for i, item in enumerate(current):
                new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
                items.update(_recurse(item, new_key))

        else:
            # anything else (write value)
            items[parent_key] = current

        return items

    return _recurse(nested_dict)
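
# a minimal sketch of `nested_dict_to_dotlist`, the inverse of `dotlist_to_nested_dict`;
# with `allow_lists=True`, list indices become key segments
# >>> nested_dict_to_dotlist({'a': {'b': {'c': 1, 'd': 2}, 'e': 3}})
# {'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}
# >>> nested_dict_to_dotlist({'a': {'xs': [10, 20]}}, allow_lists=True)
# {'a.xs.0': 10, 'a.xs.1': 20}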


def update_with_nested_dict(
    original: dict[str, Any],
    update: dict[str, Any],
) -> dict[str, Any]:
    """Update a dict with a nested dict

    Example:
    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
    {'a': {'b': 2}, 'c': -1}

    # Arguments
    - `original: dict[str, Any]`
        the dict to update (will be modified in-place)
    - `update: dict[str, Any]`
        the dict to update with

    # Returns
    - `dict`
        the updated dict
    """
    for key, value in update.items():
        if key in original:
            if isinstance(original[key], dict) and isinstance(value, dict):
                update_with_nested_dict(original[key], value)
            else:
                original[key] = value
        else:
            original[key] = value

    return original


def kwargs_to_nested_dict(
    kwargs_dict: dict[str, Any],
    sep: str = ".",
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN,
    transform_key: Optional[Callable[[str], str]] = None,
) -> dict[str, Any]:
    """given kwargs from fire, convert them to a nested dict

    if strip_prefix is not None, then all keys must start with the prefix. by default,
    will warn if an unknown prefix is found, but can be set to raise an error or ignore it:
    `when_unknown_prefix: ErrorMode`

    Example:
    ```python
    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)
    ```
    running the above script will give:
    ```bash
    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    ```

    # Arguments
    - `kwargs_dict: dict[str, Any]`
        the kwargs dict to convert
    - `sep: str = "."`
        the separator to use for nested keys
    - `strip_prefix: Optional[str] = None`
        if not None, then all keys must start with this prefix
    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
        what to do when an unknown prefix is found
    - `transform_key: Callable[[str], str] | None = None`
        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
    """
    when_unknown_prefix_ = ErrorMode.from_any(when_unknown_prefix)
    filtered_kwargs: dict[str, Any] = dict()
    for key, value in kwargs_dict.items():
        if strip_prefix is not None:
            if not key.startswith(strip_prefix):
                when_unknown_prefix_.process(
                    f"key '{key}' does not start with '{strip_prefix}'",
                    except_cls=ValueError,
                )
            else:
                key = key[len(strip_prefix) :]

        if transform_key is not None:
            key = transform_key(key)

        filtered_kwargs[key] = value

    return dotlist_to_nested_dict(filtered_kwargs, sep=sep)
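
# a minimal sketch of prefix stripping and key transformation (keys and values are illustrative):
# >>> kwargs_to_nested_dict({"model.n_layers": 2, "model.d_model": 128}, strip_prefix="model.")
# {'n_layers': 2, 'd_model': 128}
# >>> kwargs_to_nested_dict({"train-cfg.lr": 0.1}, transform_key=lambda k: k.replace("-", "_"))
# {'train_cfg': {'lr': 0.1}}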


def is_numeric_consecutive(lst: list[str]) -> bool:
    """Check if the list of keys is numeric and consecutive."""
    try:
        numbers: list[int] = [int(x) for x in lst]
        return sorted(numbers) == list(range(min(numbers), max(numbers) + 1))
    except ValueError:
        return False
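
# a small sketch: order does not matter, but every key must parse as an int and the
# values must form an unbroken range
# >>> is_numeric_consecutive(["2", "0", "1"])
# True
# >>> is_numeric_consecutive(["1", "3"])
# False
# >>> is_numeric_consecutive(["1", "a"])
# False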


def condense_nested_dicts_numeric_keys(
    data: dict[str, Any],
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric keys with matching values to ranges

    # Examples:
    ```python
    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {'1': {'[1-2]': 'a'}, '2': 'b'}
    ```
    """

    if not isinstance(data, dict):
        return data

    # Process each sub-dictionary
    for key, value in list(data.items()):
        data[key] = condense_nested_dicts_numeric_keys(value)

    # Find all numeric, consecutive keys
    if is_numeric_consecutive(list(data.keys())):
        keys: list[str] = sorted(data.keys(), key=lambda x: int(x))
    else:
        return data

    # output dict
    condensed_data: dict[str, Any] = {}

    # Identify ranges of identical values and condense
    i: int = 0
    while i < len(keys):
        j: int = i
        while j + 1 < len(keys) and data[keys[j]] == data[keys[j + 1]]:
            j += 1
        if j > i:  # Found consecutive keys with identical values
            condensed_key: str = f"[{keys[i]}-{keys[j]}]"
            condensed_data[condensed_key] = data[keys[i]]
            i = j + 1
        else:
            condensed_data[keys[i]] = data[keys[i]]
            i += 1

    return condensed_data


def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing keys with matching values

    # Examples: TODO

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)

    """

    if isinstance(data, dict):
        data = {
            key: condense_nested_dicts_matching_values(
                value, val_condense_fallback_mapping
            )
            for key, value in data.items()
        }
    else:
        return data

    # Find all identical values and condense by stitching together keys
    values_grouped: defaultdict[Any, list[str]] = defaultdict(list)
    data_persist: dict[str, Any] = dict()
    for key, value in data.items():
        if not isinstance(value, dict):
            try:
                values_grouped[value].append(key)
            except TypeError:
                # If the value is unhashable, use a fallback mapping to find a hashable representation
                if val_condense_fallback_mapping is not None:
                    values_grouped[val_condense_fallback_mapping(value)].append(key)
                else:
                    data_persist[key] = value
        else:
            data_persist[key] = value

    condensed_data = data_persist
    for value, keys in values_grouped.items():
        if len(keys) > 1:
            merged_key = f"[{', '.join(keys)}]"  # Choose an appropriate method to represent merged keys
            condensed_data[merged_key] = value
        else:
            condensed_data[keys[0]] = value

    return condensed_data
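
# a minimal sketch of the key-merging behaviour: keys whose values compare equal are
# stitched together into a single bracketed key
# >>> condense_nested_dicts_matching_values({'a': 1, 'b': 1, 'c': 2})
# {'[a, b]': 1, 'c': 2}
# >>> condense_nested_dicts_matching_values({'x': {'a': 'v', 'b': 'v'}, 'y': 'w'})
# {'x': {'[a, b]': 'v'}, 'y': 'w'}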


def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric or matching keys with matching values to ranges

    combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()`

    # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
    it's not reversible because types are lost to make the printing pretty

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)

    """

    condensed_data: dict = data
    if condense_numeric_keys:
        condensed_data = condense_nested_dicts_numeric_keys(condensed_data)
    if condense_matching_values:
        condensed_data = condense_nested_dicts_matching_values(
            condensed_data, val_condense_fallback_mapping
        )
    return condensed_data
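
# a minimal sketch of the combined passes: numeric keys are collapsed to ranges first,
# then keys with matching values are merged
# >>> condense_nested_dicts({'layers': {'0': 'relu', '1': 'relu', '2': 'gelu'}})
# {'layers': {'[0-1]': 'relu', '2': 'gelu'}}
# >>> condense_nested_dicts({'0': 1, '1': 1, '2': 1, '3': 2})
# {'[0-2]': 1, '3': 2}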


def tuple_dims_replace(
    t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None
) -> tuple[Union[int, str], ...]:
    """replace values in the shape tuple `t` with names from `dims_names_map`, if given"""
    if dims_names_map is None:
        return t
    else:
        return tuple(dims_names_map.get(x, x) for x in t)
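
# a small sketch of `tuple_dims_replace`, mapping known dimension sizes to names
# (used by `condense_tensor_dict` below via `dims_names_map`):
# >>> tuple_dims_replace((32, 128, 768), {32: "batch", 768: "d_model"})
# ('batch', 128, 'd_model')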


TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"]  # type: ignore[name-defined] # noqa: F821
TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]]  # type: ignore[name-defined] # noqa: F821
TensorDictFormats = Literal["dict", "json", "yaml", "yml"]


def _default_shapes_convert(x: tuple) -> str:
    return str(x).replace('"', "").replace("'", "")


def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args,
    shapes_convert: Callable[[tuple], Any] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
    - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` mapping strings to tensors, or a `TensorIterable` of (key, tensor) pairs (like you might get from `dict().items()`)
    - `fmt : TensorDictFormats`
        format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. yaml output requires PyYAML to be installed.
        (defaults to `'dict'`)
    - `shapes_convert : Callable[[tuple], Any]`
        conversion of a shape tuple to a string or other format
        (defaults to `lambda x: str(x).replace('"', '').replace("'", '')`, i.e. turning it into a string and removing quotes)
    - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
    - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
    - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
    - `return_format : TensorDictFormats | None`
        legacy alias for the `fmt` kwarg

    # Returns:
    - `str|dict[str, str|tuple[int, ...]]`
        a dict if `fmt='dict'`, a string for `json` or `yaml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
    - `ValueError` : if `fmt` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
    """

    # handle arg processing:
    # ----------------------------------------------------------------------
    # make all args except data and format keyword-only
    assert len(args) == 0, f"unexpected positional args: {args}"
    # handle legacy return_format
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
        )
        fmt = return_format

    # identity function for shapes_convert if not provided
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # convert to iterable
    data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # get shapes
    data_shapes: dict[str, Union[str, tuple[int, ...]]] = {
        k: shapes_convert(
            tuple_dims_replace(
                tuple(v.shape)[drop_batch_dims:],
                dims_names_map,
            )
        )
        for k, v in data_items
    }

    # nest the dict
    data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep)

    # condense the nested dict
    data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=data_nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # return in the specified format
    fmt_lower: str = fmt.lower()
    if fmt_lower == "dict":
        return data_condensed
    elif fmt_lower == "json":
        import json

        return json.dumps(data_condensed, indent=2)
    elif fmt_lower in ["yaml", "yml"]:
        try:
            import yaml  # type: ignore[import-untyped]

            return yaml.dump(data_condensed, sort_keys=False)
        except ImportError as e:
            raise ValueError("PyYAML is required for YAML output") from e
    else:
        raise ValueError(f"Invalid return format: {fmt}")
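
# a minimal sketch of `condense_tensor_dict`, assuming numpy is available
# (anything with a `.shape` attribute works, e.g. torch tensors):
# >>> import numpy as np
# >>> params = {
# ...     "blocks.0.w": np.zeros((2, 3)),
# ...     "blocks.1.w": np.zeros((2, 3)),
# ...     "head.b": np.zeros((3,)),
# ... }
# >>> condense_tensor_dict(params)
# {'blocks': {'[0-1]': {'w': '(2, 3)'}}, 'head': {'b': '(3,)'}}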