Coverage for muutils/dictmagic.py: 86%

160 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""making working with dictionaries easier 

2 

3- `DefaulterDict`: like a defaultdict, but default_factory is passed the key as an argument 

4- various methods for working with dotlist-nested dicts, converting to and from them 

5- `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges 

6- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes 

7- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict 

8""" 

9 

10from __future__ import annotations 

11 

12import typing 

13import warnings 

14from collections import defaultdict 

15from typing import ( 

16 Any, 

17 Callable, 

18 Generic, 

19 Hashable, 

20 Iterable, 

21 Literal, 

22 Optional, 

23 TypeVar, 

24 Union, 

25) 

26 

27from muutils.errormode import ErrorMode 

28 

29_KT = TypeVar("_KT") 

30_VT = TypeVar("_VT") 

31 

32 

class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
    """like a defaultdict, but default_factory is passed the key as an argument"""

    def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs):
        # positional args are rejected: they would be ambiguous alongside the
        # required factory argument
        if args:
            raise TypeError(
                f"DefaulterDict does not support positional arguments: *args = {args}"
            )
        super().__init__(**kwargs)
        # factory invoked with the missing key to produce (and cache) a value
        self.default_factory: Callable[[_KT], _VT] = default_factory

    def __getitem__(self, k: _KT) -> _VT:
        """return the stored value; on a miss, build one from the key and cache it"""
        try:
            return dict.__getitem__(self, k)
        except KeyError:
            produced: _VT = self.default_factory(k)
            dict.__setitem__(self, k, produced)
            return produced

51 

52 

53def _recursive_defaultdict_ctor() -> defaultdict: 

54 return defaultdict(_recursive_defaultdict_ctor) 

55 

56 

def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
    """Convert a defaultdict or DefaulterDict to a normal dict, recursively"""
    plain: dict = {}
    for key, value in dd.items():
        # only defaultdict-like values are converted; ordinary dicts and
        # scalars pass through untouched
        if isinstance(value, (defaultdict, DefaulterDict)):
            plain[key] = defaultdict_to_dict_recursive(value)
        else:
            plain[key] = value
    return plain

67 

68 

def dotlist_to_nested_dict(
    dot_dict: typing.Dict[str, Any], sep: str = "."
) -> typing.Dict[str, Any]:
    """Convert a dict with dot-separated keys to a nested dict

    Example:

    >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    """
    accumulator: defaultdict = _recursive_defaultdict_ctor()
    for dotted_key, value in dot_dict.items():
        if not isinstance(dotted_key, str):
            raise TypeError(f"key must be a string, got {type(dotted_key)}")
        *parents, leaf = dotted_key.split(sep)
        cursor: defaultdict = accumulator
        for part in parents:
            # the recursive defaultdict auto-creates each intermediate level
            cursor = cursor[part]
        cursor[leaf] = value
    # strip the defaultdict machinery before returning
    return defaultdict_to_dict_recursive(accumulator)

90 

91 

def nested_dict_to_dotlist(
    nested_dict: typing.Dict[str, Any],
    sep: str = ".",
    allow_lists: bool = False,
) -> dict[str, Any]:
    """Flatten a nested dict into a flat dict with `sep`-joined keys

    inverse of `dotlist_to_nested_dict`. empty sub-dicts are preserved as
    values. if `allow_lists` is True, list elements are flattened as well,
    keyed by their integer index.
    """

    def _walk(node: Any, prefix: str = "") -> typing.Iterator[typing.Tuple[str, Any]]:
        if isinstance(node, dict):
            # an empty dict at the root still yields nothing; an empty dict
            # below the root is kept as an explicit value
            if node or not prefix:
                for child_key, child in node.items():
                    yield from _walk(
                        child, f"{prefix}{sep}{child_key}" if prefix else child_key
                    )
            else:
                yield prefix, node
        elif allow_lists and isinstance(node, list):
            for idx, element in enumerate(node):
                yield from _walk(element, f"{prefix}{sep}{idx}" if prefix else str(idx))
        else:
            # leaf value (or a list when allow_lists is False)
            yield prefix, node

    return dict(_walk(nested_dict))

123 

124 

def update_with_nested_dict(
    original: dict[str, Any],
    update: dict[str, Any],
) -> dict[str, Any]:
    """Update a dict with a nested dict

    Example:
    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
    {'a': {'b': 2}, 'c': -1}

    # Arguments
    - `original: dict[str, Any]`
        the dict to update (will be modified in-place)
    - `update: dict[str, Any]`
        the dict to update with

    # Returns
    - `dict`
        the updated dict
    """
    for key, incoming in update.items():
        # recurse only when both sides hold a dict at this key; in every
        # other case the incoming value replaces (or creates) the entry
        if (
            key in original
            and isinstance(original[key], dict)
            and isinstance(incoming, dict)
        ):
            update_with_nested_dict(original[key], incoming)
        else:
            original[key] = incoming
    return original

155 

156 

def kwargs_to_nested_dict(
    kwargs_dict: dict[str, Any],
    sep: str = ".",
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN,
    transform_key: Optional[Callable[[str], str]] = None,
) -> dict[str, Any]:
    """given kwargs from fire, convert them to a nested dict

    if `strip_prefix` is not None, every key is expected to start with that
    prefix. keys with an unknown prefix are handled according to
    `when_unknown_prefix` (warn by default; can raise or ignore instead).

    Example:
    ```python
    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)
    ```
    running the above script will give:
    ```bash
    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    ```

    # Arguments
    - `kwargs_dict: dict[str, Any]`
        the kwargs dict to convert
    - `sep: str = "."`
        the separator to use for nested keys
    - `strip_prefix: Optional[str] = None`
        if not None, then all keys must start with this prefix
    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
        what to do when an unknown prefix is found
    - `transform_key: Callable[[str], str] | None = None`
        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
    """
    error_mode: ErrorMode = ErrorMode.from_any(when_unknown_prefix)
    cleaned: dict[str, Any] = {}
    for raw_key, value in kwargs_dict.items():
        key: str = raw_key
        if strip_prefix is not None:
            if key.startswith(strip_prefix):
                key = key[len(strip_prefix) :]
            else:
                # unknown prefix: warn/raise/ignore per the error mode;
                # when not raising, the key is kept un-stripped
                error_mode.process(
                    f"key '{key}' does not start with '{strip_prefix}'",
                    except_cls=ValueError,
                )
        if transform_key is not None:
            key = transform_key(key)
        cleaned[key] = value
    return dotlist_to_nested_dict(cleaned, sep=sep)

212 

213 

def is_numeric_consecutive(lst: list[str]) -> bool:
    """Check if the list of keys is numeric and consecutive."""
    try:
        numbers: list[int] = sorted(int(item) for item in lst)
    except ValueError:
        # at least one key is not parseable as an int
        return False
    if not numbers:
        # an empty collection is not considered consecutive
        return False
    # consecutive (and duplicate-free) iff the sorted ints form an unbroken range
    return numbers == list(range(numbers[0], numbers[-1] + 1))

221 

222 

def condense_nested_dicts_numeric_keys(
    data: dict[str, Any],
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric keys with matching values to ranges

    # Examples:
    ```python
    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {"1": {"[1-2]": "a"}, "2": "b"}
    ```
    """
    # non-dict leaves pass through untouched
    if not isinstance(data, dict):
        return data

    # condense children first (note: rewrites values of `data` in place)
    for child_key in list(data.keys()):
        data[child_key] = condense_nested_dicts_numeric_keys(data[child_key])

    # this level is only condensed when every key parses as an int and the
    # ints form an unbroken run
    if not is_numeric_consecutive(list(data.keys())):
        return data

    ordered_keys: list[str] = sorted(data.keys(), key=int)

    condensed: dict[str, Any] = {}
    start: int = 0
    total: int = len(ordered_keys)
    while start < total:
        # grow the run while neighboring keys share an identical value
        end: int = start
        while end + 1 < total and data[ordered_keys[end]] == data[ordered_keys[end + 1]]:
            end += 1
        if end > start:
            # run of length > 1: collapse to a "[first-last]" range key
            condensed[f"[{ordered_keys[start]}-{ordered_keys[end]}]"] = data[
                ordered_keys[start]
            ]
        else:
            condensed[ordered_keys[start]] = data[ordered_keys[start]]
        start = end + 1

    return condensed

268 

269 

def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing keys with matching values

    keys whose (non-dict) values compare equal are merged into a single
    "[k1, k2, ...]" key. unhashable values are merged only if
    `val_condense_fallback_mapping` supplies a hashable stand-in (which then
    replaces the value in the output); otherwise they are kept as-is.

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)
    """
    # non-dict leaves pass through untouched
    if not isinstance(data, dict):
        return data

    # condense every child first (builds a new dict; input is not mutated)
    processed: dict[str, Any] = {
        child_key: condense_nested_dicts_matching_values(
            child_value, val_condense_fallback_mapping
        )
        for child_key, child_value in data.items()
    }

    # bucket keys by (hashable) value; dict values and unmergeable values
    # are carried over unchanged
    groups: defaultdict = defaultdict(list)
    kept: dict[str, Any] = {}
    for child_key, child_value in processed.items():
        if isinstance(child_value, dict):
            kept[child_key] = child_value
            continue
        try:
            groups[child_value].append(child_key)
        except TypeError:
            # unhashable value: either map it to something hashable, or keep it
            if val_condense_fallback_mapping is not None:
                groups[val_condense_fallback_mapping(child_value)].append(child_key)
            else:
                kept[child_key] = child_value

    result: dict[str, Any] = kept
    for grouped_value, grouped_keys in groups.items():
        if len(grouped_keys) > 1:
            result[f"[{', '.join(grouped_keys)}]"] = grouped_value
        else:
            result[grouped_keys[0]] = grouped_value

    return result

322 

323 

def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric or matching keys with matching values to ranges

    combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()`

    # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
    it's not reversible because types are lost to make the printing pretty

    # Parameters:
    - `data : dict[str, Any]`
        data to process
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)
    """
    # numeric-key condensing runs first so that matching-value condensing
    # sees the already-collapsed range keys
    result: dict = data
    if condense_numeric_keys:
        result = condense_nested_dicts_numeric_keys(result)
    if condense_matching_values:
        result = condense_nested_dicts_matching_values(
            result, val_condense_fallback_mapping
        )
    return result

360 

361 

def tuple_dims_replace(
    t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None
) -> tuple[Union[int, str], ...]:
    """replace dimension sizes in a shape tuple with names from `dims_names_map`

    sizes not present in the map are left as-is; with no map, the tuple is
    returned unchanged.
    """
    if dims_names_map is None:
        return t
    renamed: list = [dims_names_map.get(dim, dim) for dim in t]
    return tuple(renamed)

369 

370 

# type aliases for the tensor-shape utilities below.
# torch / numpy appear only inside string annotations so that neither library
# is a hard import dependency of this module (hence the noqa / type-ignore).
TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"]  # type: ignore[name-defined] # noqa: F821
TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]]  # type: ignore[name-defined] # noqa: F821
# accepted values for the `fmt` argument of `condense_tensor_dict`
TensorDictFormats = Literal["dict", "json", "yaml", "yml"]

374 

375 

376def _default_shapes_convert(x: tuple) -> str: 

377 return str(x).replace('"', "").replace("'", "") 

378 

379 

def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args,
    shapes_convert: Callable[[tuple], Any] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
    - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        a either a `TensorDict` dict from strings to tensors, or an `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items())` )
    - `fmt : TensorDictFormats`
        format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed.
        (defaults to `'dict'`)
    - `shapes_convert : Callable[[tuple], Any]`
        conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes)
        (defaults to `lambda x: str(x).replace('"', '').replace("'", '')`)
    - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
    - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
    - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
    - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
    - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
    - `return_format : TensorDictFormats | None`
        deprecated alias for the `fmt` kwarg (emits a `DeprecationWarning` and overrides `fmt` when given)

    # Returns:
    - `str|dict[str, str|tuple[int, ...]]`
        dict if `fmt='dict'`, a string for `json` or `yaml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
    - `ValueError` : if `fmt` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
    """

    # handle arg processing:
    # ----------------------------------------------------------------------
    # make all args except data and format keyword-only
    assert len(args) == 0, f"unexpected positional args: {args}"
    # handle legacy return_format
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
        )
        fmt = return_format

    # identity function for shapes_convert if not provided
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # convert to iterable -- accepts either a mapping (via .items()) or an
    # already-iterable sequence of (key, tensor) pairs
    data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # get shapes: drop leading batch dims, optionally rename dimension sizes,
    # then convert each shape tuple via shapes_convert
    data_shapes: dict[str, Union[str, tuple[int, ...]]] = {
        k: shapes_convert(
            tuple_dims_replace(
                tuple(v.shape)[drop_batch_dims:],
                dims_names_map,
            )
        )
        for k, v in data_items
    }

    # nest the dict (keys like "blocks.0.attn.W_Q" become nested levels)
    data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep)

    # condense the nested dict
    data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=data_nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # return in the specified format
    fmt_lower: str = fmt.lower()
    if fmt_lower == "dict":
        return data_condensed
    elif fmt_lower == "json":
        # imported lazily: only needed for this output format
        import json

        return json.dumps(data_condensed, indent=2)
    elif fmt_lower in ["yaml", "yml"]:
        try:
            import yaml  # type: ignore[import-untyped]

            return yaml.dump(data_condensed, sort_keys=False)
        except ImportError as e:
            raise ValueError("PyYAML is required for YAML output") from e
    else:
        raise ValueError(f"Invalid return format: {fmt}")