Coverage for muutils / dictmagic.py: 86%

160 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1"""making working with dictionaries easier 

2 

3- `DefaulterDict`: like a defaultdict, but default_factory is passed the key as an argument 

4- various methods for working with dotlist-nested dicts, converting to and from them 

5- `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges 

6- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes 

7- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict 

8""" 

9 

10from __future__ import annotations 

11 

12import typing 

13import warnings 

14from collections import defaultdict 

15from typing import ( 

16 Any, 

17 Callable, 

18 Generic, 

19 Hashable, 

20 Iterable, 

21 Literal, 

22 Optional, 

23 TypeVar, 

24 Union, 

25) 

26 

27from muutils.errormode import ErrorMode 

28 

29_KT = TypeVar("_KT") 

30_VT = TypeVar("_VT") 

31 

32 

class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
    """like a defaultdict, but default_factory is passed the key as an argument"""

    def __init__(
        self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any
    ) -> None:
        # positional args are rejected so the factory can never be confused
        # with dict-initialization data
        if args:
            raise TypeError(
                f"DefaulterDict does not support positional arguments: *args = {args}"
            )
        super().__init__(**kwargs)
        self.default_factory: Callable[[_KT], _VT] = default_factory

    def __getitem__(self, k: _KT) -> _VT:
        # EAFP: try the plain dict lookup first; on a miss, compute the value
        # from the key itself and cache it (like defaultdict.__missing__)
        try:
            return dict.__getitem__(self, k)
        except KeyError:
            value: _VT = self.default_factory(k)
            dict.__setitem__(self, k, value)
            return value

53 

54 

def _recursive_defaultdict_ctor() -> defaultdict:
    """Factory for an arbitrarily deep `defaultdict`: every missing key yields another such defaultdict."""
    return defaultdict(_recursive_defaultdict_ctor)

57 

58 

def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
    """Convert a defaultdict or DefaulterDict to a normal dict, recursively"""
    result: dict = {}
    for key, value in dd.items():
        if isinstance(value, (defaultdict, DefaulterDict)):
            # recurse into nested default-dicts so no factory-backed
            # containers survive in the output
            result[key] = defaultdict_to_dict_recursive(value)
        else:
            result[key] = value
    return result

69 

70 

def dotlist_to_nested_dict(
    dot_dict: typing.Dict[str, Any], sep: str = "."
) -> typing.Dict[str, Any]:
    """Convert a dict with dot-separated keys to a nested dict

    Example:

    >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    """
    tree: defaultdict = _recursive_defaultdict_ctor()
    for dotted_key, value in dot_dict.items():
        if not isinstance(dotted_key, str):
            raise TypeError(f"key must be a string, got {type(dotted_key)}")
        # split the path; everything but the last segment addresses a sub-dict
        *parents, leaf = dotted_key.split(sep)
        node: defaultdict = tree
        for part in parents:
            node = node[part]
        node[leaf] = value
    return defaultdict_to_dict_recursive(tree)

92 

93 

def nested_dict_to_dotlist(
    nested_dict: typing.Dict[str, Any],
    sep: str = ".",
    allow_lists: bool = False,
) -> dict[str, Any]:
    """Convert a nested dict to a flat dict with `sep`-separated keys

    inverse of `dotlist_to_nested_dict` (up to types; empty containers are
    kept as leaf values)

    # Arguments
    - `nested_dict: dict[str, Any]`
        the nested dict to flatten
    - `sep: str = "."`
        separator joining nesting levels in the output keys
    - `allow_lists: bool = False`
        if `True`, list elements are flattened too, using their index as the key

    # Returns
    - `dict[str, Any]`
        flat dict mapping dotted paths to leaf values

    Example:

    >>> nested_dict_to_dotlist({'a': {'b': {'c': 1, 'd': 2}, 'e': 3}})
    {'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}
    """

    def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]:
        items: dict = dict()

        new_key: str
        if isinstance(current, dict):
            # dict case
            if not current and parent_key:
                # preserve empty dicts as leaves instead of dropping them
                items[parent_key] = current
            else:
                for k, v in current.items():
                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
                    items.update(_recurse(v, new_key))

        elif allow_lists and isinstance(current, list):
            # list case
            if not current and parent_key:
                # BUGFIX: empty lists used to vanish from the output entirely;
                # keep them as leaves, mirroring the empty-dict case above
                items[parent_key] = current
            else:
                for i, item in enumerate(current):
                    new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
                    items.update(_recurse(item, new_key))

        else:
            # anything else (write value)
            items[parent_key] = current

        return items

    return _recurse(nested_dict)

125 

126 

def update_with_nested_dict(
    original: dict[str, Any],
    update: dict[str, Any],
) -> dict[str, Any]:
    """Update a dict with a nested dict

    Example:
    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
    {'a': {'b': 2}, 'c': -1}

    # Arguments
    - `original: dict[str, Any]`
        the dict to update (will be modified in-place)
    - `update: dict[str, Any]`
        the dict to update with

    # Returns
    - `dict`
        the updated dict
    """
    for key, new_value in update.items():
        # only merge recursively when BOTH sides hold a dict at this key;
        # in every other case the new value simply overwrites (or inserts)
        should_merge: bool = (
            key in original
            and isinstance(original[key], dict)
            and isinstance(new_value, dict)
        )
        if should_merge:
            update_with_nested_dict(original[key], new_value)
        else:
            original[key] = new_value

    return original

157 

158 

def kwargs_to_nested_dict(
    kwargs_dict: dict[str, Any],
    sep: str = ".",
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN,
    transform_key: Optional[Callable[[str], str]] = None,
) -> dict[str, Any]:
    """given kwargs from fire, convert them to a nested dict

    if strip_prefix is not None, then all keys must start with the prefix. by default,
    will warn if an unknown prefix is found, but can be set to raise an error or ignore it:
    `when_unknown_prefix: ErrorMode`

    Example:
    ```python
    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)
    ```
    running the above script will give:
    ```bash
    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    ```

    # Arguments
    - `kwargs_dict: dict[str, Any]`
        the kwargs dict to convert
    - `sep: str = "."`
        the separator to use for nested keys
    - `strip_prefix: Optional[str] = None`
        if not None, then all keys must start with this prefix
    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
        what to do when an unknown prefix is found
    - `transform_key: Callable[[str], str] | None = None`
        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
    """
    error_mode: ErrorMode = ErrorMode.from_any(when_unknown_prefix)
    processed: dict[str, Any] = dict()
    for raw_key, value in kwargs_dict.items():
        key: str = raw_key
        if strip_prefix is not None:
            if key.startswith(strip_prefix):
                key = key[len(strip_prefix) :]
            else:
                # unknown prefix: warn/raise/ignore per error_mode; in the
                # non-raising modes the key is kept unstripped
                error_mode.process(
                    f"key '{key}' does not start with '{strip_prefix}'",
                    except_cls=ValueError,
                )

        if transform_key is not None:
            key = transform_key(key)

        processed[key] = value

    return dotlist_to_nested_dict(processed, sep=sep)

214 

215 

def is_numeric_consecutive(lst: list[str]) -> bool:
    """Check if the list of keys is numeric and consecutive."""
    try:
        values: list[int] = sorted(int(item) for item in lst)
    except ValueError:
        # at least one key is not an integer literal
        return False
    if not values:
        # matches original behavior: an empty list is not consecutive
        return False
    # consecutive iff the sorted values form an unbroken integer range
    return values == list(range(values[0], values[-1] + 1))

223 

224 

def condense_nested_dicts_numeric_keys(
    data: dict[str, Any],
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric keys with matching values to ranges

    # Examples:
    ```python
    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {"1": {"[1-2]": "a"}, "2": "b"}
    ```
    """

    # non-dict leaves pass through untouched
    if not isinstance(data, dict):
        return data

    # condense children first (updates `data` in place, as before)
    for child_key in list(data.keys()):
        data[child_key] = condense_nested_dicts_numeric_keys(data[child_key])

    # only condense this level if every key is numeric and the keys are consecutive
    if not is_numeric_consecutive(list(data.keys())):
        return data

    ordered_keys: list[str] = sorted(data.keys(), key=int)

    # output dict
    result: dict[str, Any] = {}

    # scan runs of adjacent keys that share an identical value
    start: int = 0
    total: int = len(ordered_keys)
    while start < total:
        end: int = start
        while (
            end + 1 < total
            and data[ordered_keys[end]] == data[ordered_keys[end + 1]]
        ):
            end += 1
        if end == start:
            # singleton run: keep the key unchanged
            result[ordered_keys[start]] = data[ordered_keys[start]]
        else:
            # run of identical values: collapse to a "[first-last]" key
            result[f"[{ordered_keys[start]}-{ordered_keys[end]}]"] = data[
                ordered_keys[start]
            ]
        start = end + 1

    return result

270 

271 

def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing keys with matching values

    # Examples: TODO

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)

    """

    # leaves pass through untouched; dicts are condensed bottom-up
    if not isinstance(data, dict):
        return data
    data = {
        k: condense_nested_dicts_matching_values(v, val_condense_fallback_mapping)
        for k, v in data.items()
    }

    # group keys by their (hashable) value; dict values and unhashable
    # values without a fallback are kept as-is
    groups: defaultdict[Any, list[str]] = defaultdict(list)
    kept: dict[str, Any] = dict()
    for key, value in data.items():
        if isinstance(value, dict):
            kept[key] = value
            continue
        try:
            groups[value].append(key)
        except TypeError:
            # unhashable value: use the fallback to get a hashable stand-in,
            # or keep the entry untouched if no fallback was supplied
            if val_condense_fallback_mapping is not None:
                groups[val_condense_fallback_mapping(value)].append(key)
            else:
                kept[key] = value

    # stitch together the keys of each group that has more than one member
    result = kept
    for value, keys in groups.items():
        if len(keys) > 1:
            result[f"[{', '.join(keys)}]"] = value
        else:
            result[keys[0]] = value

    return result

324 

325 

def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric or matching keys with matching values to ranges

    combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()`

    # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
    it's not reversible because types are lost to make the printing pretty

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)

    """

    # build the pipeline of enabled condensing passes, then apply in order
    # (numeric-key condensing always runs before value-matching, as before)
    passes: list = []
    if condense_numeric_keys:
        passes.append(condense_nested_dicts_numeric_keys)
    if condense_matching_values:
        passes.append(
            lambda d: condense_nested_dicts_matching_values(
                d, val_condense_fallback_mapping
            )
        )

    result: dict = data
    for condense_pass in passes:
        result = condense_pass(result)
    return result

362 

363 

def tuple_dims_replace(
    t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None
) -> tuple[Union[int, str], ...]:
    """Replace dimension sizes in a shape tuple with names, via a size-to-name mapping

    # Parameters:
    - `t : tuple[int, ...]`
        shape tuple to process
    - `dims_names_map : dict[int, str] | None`
        map from dimension size to replacement name; sizes not present in the
        map pass through unchanged
        (defaults to `None`, meaning `t` is returned as-is)

    # Returns:
    - `tuple[int | str, ...]`
        the shape tuple with mapped sizes replaced by their names
    """
    if dims_names_map is None:
        return t
    else:
        return tuple(dims_names_map.get(x, x) for x in t)

371 

372 

# type aliases for the tensor-shape utilities below; torch/numpy types are
# written as string forward references (note the `noqa: F821`), so neither
# package has to be imported by this module
TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"]  # type: ignore[name-defined] # noqa: F821
TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]]  # type: ignore[name-defined] # noqa: F821
# output formats accepted by `condense_tensor_dict`
TensorDictFormats = Literal["dict", "json", "yaml", "yml"]

376 

377 

378def _default_shapes_convert(x: tuple) -> str: 

379 return str(x).replace('"', "").replace("'", "") 

380 

381 

def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args: Any,
    shapes_convert: Callable[
        [tuple[Union[int, str], ...]], Any
    ] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
    - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` mapping strings to tensors, or a `TensorIterable` of
        (key, tensor) pairs (like you might get from a `dict().items()`)
    - `fmt : TensorDictFormats`
        format to return the result in -- either a dict, or dump to json/yaml directly
        for pretty printing. will crash if yaml is not installed.
        (defaults to `'dict'`)
    - `shapes_convert : Callable[[tuple], Any]`
        conversion of a shape tuple to a string or other format
        (defaults to turning it into a string and removing quotes)
    - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
    - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
    - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"),
        passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not
        hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
    - `return_format : TensorDictFormats | None`
        legacy alias for `fmt` kwarg

    # Returns:
    - `str|dict[str, str|tuple[int, ...]]`
        dict if `fmt='dict'`, a string for `json` or `yaml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
    - `ValueError` : if the format is not one of 'dict', 'json', or 'yaml', or if
        you try to use 'yaml' output without having PyYAML installed
    """

    # everything after `data` and `fmt` is keyword-only
    assert len(args) == 0, f"unexpected positional args: {args}"

    # `return_format` is a deprecated alias for `fmt`
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
        )
        fmt = return_format

    # passing None for shapes_convert means "give me the raw shape tuples"
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # accept either a mapping or an iterable of (name, tensor) pairs
    pairs: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]"  # type: ignore # noqa: F821
    if hasattr(data, "items") and callable(data.items):  # type: ignore
        pairs = data.items()  # type: ignore
    else:
        pairs = data  # type: ignore

    # map each dotted name to its (possibly truncated and renamed) shape
    shapes_flat: dict[str, Union[str, tuple[int, ...]]] = {}
    for name, tensor in pairs:
        trimmed_shape = tuple(tensor.shape)[drop_batch_dims:]
        shapes_flat[name] = shapes_convert(
            tuple_dims_replace(trimmed_shape, dims_names_map)
        )

    # nest by the separator, then condense for readability
    condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=dotlist_to_nested_dict(shapes_flat, sep=sep),
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # serialize in the requested format
    fmt_normalized: str = fmt.lower()
    if fmt_normalized == "dict":
        return condensed
    elif fmt_normalized == "json":
        import json

        return json.dumps(condensed, indent=2)
    elif fmt_normalized in ("yaml", "yml"):
        try:
            import yaml  # type: ignore[import-untyped]

            return yaml.dump(condensed, sort_keys=False)
        except ImportError as e:
            raise ValueError("PyYAML is required for YAML output") from e
    else:
        raise ValueError(f"Invalid return format: {fmt}")