Coverage for muutils/tensor_utils.py: 90%

124 statements  

coverage.py v7.11.0, created at 2025-10-28 17:24 +0000

1"""utilities for working with tensors and arrays. 

2 

3notably: 

4 

5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 

6- `DTYPE_MAP` mapping string representations of types to their type 

7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 

8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 

9 

10""" 

11 

12from __future__ import annotations 

13 

14import json 

15import typing 

16 

17import jaxtyping 

18import numpy as np 

19import torch 

20 

21from muutils.errormode import ErrorMode 

22from muutils.dictmagic import dotlist_to_nested_dict 

23 

24# pylint: disable=missing-class-docstring 

25 

26 

27TYPE_TO_JAX_DTYPE: dict = { 

28 float: jaxtyping.Float, 

29 int: jaxtyping.Int, 

30 jaxtyping.Float: jaxtyping.Float, 

31 jaxtyping.Int: jaxtyping.Int, 

32 # bool 

33 bool: jaxtyping.Bool, 

34 jaxtyping.Bool: jaxtyping.Bool, 

35 np.bool_: jaxtyping.Bool, 

36 torch.bool: jaxtyping.Bool, 

37 # numpy float 

38 np.float16: jaxtyping.Float, 

39 np.float32: jaxtyping.Float, 

40 np.float64: jaxtyping.Float, 

41 np.half: jaxtyping.Float, 

42 np.single: jaxtyping.Float, 

43 np.double: jaxtyping.Float, 

44 # numpy int 

45 np.int8: jaxtyping.Int, 

46 np.int16: jaxtyping.Int, 

47 np.int32: jaxtyping.Int, 

48 np.int64: jaxtyping.Int, 

49 np.longlong: jaxtyping.Int, 

50 np.short: jaxtyping.Int, 

51 np.uint8: jaxtyping.Int, 

52 # torch float 

53 torch.float: jaxtyping.Float, 

54 torch.float16: jaxtyping.Float, 

55 torch.float32: jaxtyping.Float, 

56 torch.float64: jaxtyping.Float, 

57 torch.half: jaxtyping.Float, 

58 torch.double: jaxtyping.Float, 

59 torch.bfloat16: jaxtyping.Float, 

60 # torch int 

61 torch.int: jaxtyping.Int, 

62 torch.int8: jaxtyping.Int, 

63 torch.int16: jaxtyping.Int, 

64 torch.int32: jaxtyping.Int, 

65 torch.int64: jaxtyping.Int, 

66 torch.long: jaxtyping.Int, 

67 torch.short: jaxtyping.Int, 

68} 

69"dict mapping python, numpy, and torch types to `jaxtyping` types" 

70 

71# we check for version here, so it shouldn't error 

72if np.version.version < "2.0.0": 

73 TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float # type: ignore[attr-defined] 

74 TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int # type: ignore[attr-defined] 
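
# --- illustrative usage sketch (not part of the original module) ---
# TYPE_TO_JAX_DTYPE resolves a concrete dtype to the corresponding jaxtyping
# annotation class, which can then be parameterized with an array type and a
# shape string:
#
#     assert TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float
#     BatchedInts = TYPE_TO_JAX_DTYPE[torch.int64][torch.Tensor, "batch seq"]
#     # equivalent to jaxtyping.Int[torch.Tensor, "batch seq"]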


# TODO: add proper type annotations to this signature
# TODO: maybe get rid of this altogether?
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode_ = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            raise TypeError(f"Type {cls.__name__} cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode_.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = list()
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]
            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        _BaseArray.__doc__ = "default_jax_dtype = {default_jax_dtype}\narray_type = {array_type}"

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray


if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()


ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
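
# --- illustrative usage sketch (not part of the original module) ---
# ATensor and NDArray can be used directly in annotations; the dtype defaults
# to Float and can be overridden with a second parameter:
#
#     def scale(x: ATensor["batch dim", np.float32]) -> ATensor["batch dim"]:
#         return x * 2
#
#     counts: NDArray["n", int]  # resolves to jaxtyping.Int[np.ndarray, "n"]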


def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
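
# --- illustrative usage sketch (not part of the original module) ---
# The round-trip through a zero-dimensional numpy array lets torch pick the
# matching dtype, and torch dtypes pass through unchanged:
#
#     assert numpy_to_torch_dtype(np.float32) == torch.float32
#     assert numpy_to_torch_dtype(torch.int64) == torch.int64  # passthrough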


DTYPE_LIST: list = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # simplest
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.longlong,
        np.short,
        # simplest
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"

if np.version.version < "2.0.0":
    DTYPE_LIST.extend([np.float_, np.int_])  # type: ignore[attr-defined]

DTYPE_MAP: dict = {
    **{str(x): x for x in DTYPE_LIST},
    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
}
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict = {
    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
}
"mapping from string representations of types to specifically torch types"

# no idea why we have to do this, smh
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
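
# --- illustrative usage sketch (not part of the original module) ---
# DTYPE_MAP keys are the str() and numpy __name__ forms of the listed dtypes,
# so both framework-prefixed and bare names resolve:
#
#     DTYPE_MAP["torch.float32"]      # torch.float32
#     DTYPE_MAP["float32"]            # np.float32
#     TORCH_DTYPE_MAP["float32"]      # torch.float32
#     TORCH_DTYPE_MAP["bool"]         # torch.bool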


TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}
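
# --- illustrative usage sketch (not part of the original module) ---
# TORCH_OPTIMIZERS_MAP lets a config name an optimizer as a plain string:
#
#     model = torch.nn.Linear(4, 2)
#     optim_cls = TORCH_OPTIMIZERS_MAP["AdamW"]
#     optimizer = optim_cls(model.parameters(), lr=1e-3)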


def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)


def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)


def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)


def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)


def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)


def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
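
# --- illustrative usage sketch (not part of the original module) ---
# Left-padding is the default; rpad=True (or the rpad_* helpers) pads on the right:
#
#     lpad_tensor(torch.tensor([1, 2, 3]), 5)            # tensor([0, 0, 1, 2, 3])
#     rpad_array(np.array([1.0, 2.0]), 4, pad_value=-1)  # array([ 1.,  2., -1., -1.])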


def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})


def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent won't play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
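
# --- illustrative usage sketch (not part of the original module) ---
# Dotted keys are expanded into nested dicts, which makes shape mismatches in
# large state dicts easier to scan:
#
#     d = {"layer1.weight": torch.zeros(4, 3), "layer1.bias": torch.zeros(4)}
#     get_dict_shapes(d)
#     # {"layer1": {"weight": (4, 3), "bias": (4,)}}
#     print(string_dict_shapes(d))  # same structure, JSON-formatted, shapes as strings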


class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass


class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass


class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass


class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass


def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
        (defaults to `1e-5`)
    - `atol : float`
        (defaults to `1e-8`)
    - `verbose : bool`
        (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )