# muutils/tensor_utils.py

1"""utilities for working with tensors and arrays. 

2 

3notably: 

4 

5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 

6- `DTYPE_MAP` mapping string representations of types to their type 

7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 

8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 

9 

10""" 

11 

from __future__ import annotations

import json
import typing

import jaxtyping
import numpy as np
import torch

from muutils.errormode import ErrorMode
from muutils.dictmagic import dotlist_to_nested_dict

# pylint: disable=missing-class-docstring

TYPE_TO_JAX_DTYPE: dict = {
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    bool: jaxtyping.Bool,
    jaxtyping.Bool: jaxtyping.Bool,
    np.bool_: jaxtyping.Bool,
    torch.bool: jaxtyping.Bool,
    # numpy float
    np.float16: jaxtyping.Float,
    np.float32: jaxtyping.Float,
    np.float64: jaxtyping.Float,
    np.half: jaxtyping.Float,
    np.single: jaxtyping.Float,
    np.double: jaxtyping.Float,
    # numpy int
    np.int8: jaxtyping.Int,
    np.int16: jaxtyping.Int,
    np.int32: jaxtyping.Int,
    np.int64: jaxtyping.Int,
    np.longlong: jaxtyping.Int,
    np.short: jaxtyping.Int,
    np.uint8: jaxtyping.Int,
    # torch float
    torch.float: jaxtyping.Float,
    torch.float16: jaxtyping.Float,
    torch.float32: jaxtyping.Float,
    torch.float64: jaxtyping.Float,
    torch.half: jaxtyping.Float,
    torch.double: jaxtyping.Float,
    torch.bfloat16: jaxtyping.Float,
    # torch int
    torch.int: jaxtyping.Int,
    torch.int8: jaxtyping.Int,
    torch.int16: jaxtyping.Int,
    torch.int32: jaxtyping.Int,
    torch.int64: jaxtyping.Int,
    torch.long: jaxtyping.Int,
    torch.short: jaxtyping.Int,
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"

if np.version.version < "2.0.0":
    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float
    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int

# TODO: add proper type annotations to this signature
# TODO: maybe get rid of this altogether?
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode_ = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            # use the factory's `name` so the error message names the generated class
            raise TypeError(f"Type {name} cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode_.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = list()
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]
            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        # fallback template (e.g. when docstrings are stripped); kept compatible
        # with the .format() call below
        _BaseArray.__doc__ = "default_jax_dtype = {default_jax_dtype}\narray_type = {array_type}"

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray

if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy, and there is no way to run them only when
    # mypy is not checking, so later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
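# Illustrative usage sketch (added for this write-up, not part of the original
# module): the factory-made aliases are meant to be subscripted in annotations.
# A bare string uses the factory's default dtype; a (shape, dtype) pair is
# resolved through TYPE_TO_JAX_DTYPE. Nothing here performs runtime shape checks.
def _example_jaxtype_aliases() -> None:
    float_tensor_ann = ATensor["batch dim"]  # jaxtyping.Float[torch.Tensor, "batch dim"]
    int_array_ann = NDArray["batch dim", np.int64]  # jaxtyping.Int[np.ndarray, "batch dim"]
    print(float_tensor_ann, int_array_ann)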

def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
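# Illustrative sketch (not in the original source): the conversion above works by
# round-tripping a zero-dim numpy array through torch; torch dtypes pass through.
def _example_numpy_to_torch_dtype() -> None:
    assert numpy_to_torch_dtype(np.float32) == torch.float32
    assert numpy_to_torch_dtype(torch.float64) == torch.float64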

DTYPE_LIST: list = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # uint8 and bool
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.longlong,
        np.short,
        # uint8 and bool
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"

if np.version.version < "2.0.0":
    DTYPE_LIST.extend([np.float_, np.int_])

DTYPE_MAP: dict = {
    **{str(x): x for x in DTYPE_LIST},
    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
}
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict = {
    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
}
"mapping from string representations of types to specifically torch types"

# explicitly ensure a plain "bool" key (str(bool) is "<class 'bool'>", so the
# comprehensions above don't guarantee it)
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
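# Illustrative sketch (not in the original source): DTYPE_MAP is keyed both by
# `str(dtype)` (e.g. "torch.float32") and, for numpy types, by `dtype.__name__`
# (e.g. "float32"); TORCH_DTYPE_MAP maps the same keys to torch dtypes.
def _example_dtype_maps() -> None:
    assert DTYPE_MAP["float32"] is np.float32
    assert DTYPE_MAP["torch.float32"] is torch.float32
    assert TORCH_DTYPE_MAP["float32"] is torch.float32
    assert TORCH_DTYPE_MAP["bool"] is torch.bool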

TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}
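# Illustrative sketch (not in the original source): the optimizer map lets a
# config name an optimizer as a plain string and look the class up at runtime.
def _example_optimizer_from_string() -> None:
    model = torch.nn.Linear(4, 2)
    optimizer_cls = TORCH_OPTIMIZERS_MAP["AdamW"]
    optimizer = optimizer_cls(model.parameters(), lr=1e-3)
    print(type(optimizer).__name__)  # prints "AdamW"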

def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)

def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)


def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
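# Illustrative sketch (not in the original source): padding goes on the left by
# default and on the right with `rpad=True`; the l/r wrappers just fix that flag.
def _example_pad_tensor() -> None:
    x = torch.tensor([1.0, 2.0, 3.0])
    assert torch.equal(lpad_tensor(x, 5), torch.tensor([0.0, 0.0, 1.0, 2.0, 3.0]))
    assert torch.equal(rpad_tensor(x, 5), torch.tensor([1.0, 2.0, 3.0, 0.0, 0.0]))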

def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)

def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)


def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
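# Illustrative sketch (not in the original source): the numpy padding helpers
# mirror the tensor versions above.
def _example_pad_array() -> None:
    a = np.array([1.0, 2.0, 3.0])
    assert np.array_equal(lpad_array(a, 5), np.array([0.0, 0.0, 1.0, 2.0, 3.0]))
    assert np.array_equal(rpad_array(a, 5), np.array([1.0, 2.0, 3.0, 0.0, 0.0]))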

def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})

def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent won't play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
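# Illustrative sketch (not in the original source): dotted keys are expanded into
# nested dicts by `dotlist_to_nested_dict`, and `string_dict_shapes` renders the
# result as indented JSON. The nesting shown in the comment below is an assumption
# about that helper's behavior.
def _example_dict_shapes() -> None:
    state_dict = {
        "layer.weight": torch.zeros(3, 4),
        "layer.bias": torch.zeros(3),
    }
    print(get_dict_shapes(state_dict))  # expected: {"layer": {"weight": (3, 4), "bias": (3,)}}
    print(string_dict_shapes(state_dict))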

class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass


class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass


class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass


class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass

def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
       (defaults to `1e-5`)
    - `atol : float`
       (defaults to `1e-8`)
    - `verbose : bool`
       (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """

    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
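# Illustrative sketch (not in the original source): the dedicated exception types
# let a caller tell apart key, shape, and value mismatches.
def _example_compare_state_dicts() -> None:
    sd_a = {"w": torch.zeros(2, 2), "b": torch.zeros(2)}
    sd_b = {"w": torch.zeros(2, 2), "b": torch.ones(2)}
    try:
        compare_state_dicts(sd_a, sd_b)
    except StateDictKeysError:
        print("keys differ")
    except StateDictShapeError:
        print("shapes differ")
    except StateDictValueError:
        print("values differ")  # this is the branch hit here: "b" values don't match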