Coverage for muutils/tensor_utils.py: 99%

89 statements  

coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1"""utilities for working with tensors and arrays. 

2 

3notably: 

4 

5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 

6- `DTYPE_MAP` mapping string representations of types to their type 

7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 

8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 

9 

10""" 

11 

12from __future__ import annotations 

13 

14import json 

15import typing 

16from typing import Any 

17 

18import jaxtyping 

19import numpy as np 

20import torch 

21 

22from muutils.dictmagic import dotlist_to_nested_dict 

23 

24# pylint: disable=missing-class-docstring 

25 

26 

TYPE_TO_JAX_DTYPE: dict[Any, Any] = {
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    bool: jaxtyping.Bool,
    jaxtyping.Bool: jaxtyping.Bool,
    np.bool_: jaxtyping.Bool,
    torch.bool: jaxtyping.Bool,
    # numpy float
    np.float16: jaxtyping.Float,
    np.float32: jaxtyping.Float,
    np.float64: jaxtyping.Float,
    np.half: jaxtyping.Float,
    np.single: jaxtyping.Float,
    np.double: jaxtyping.Float,
    # numpy int
    np.int8: jaxtyping.Int,
    np.int16: jaxtyping.Int,
    np.int32: jaxtyping.Int,
    np.int64: jaxtyping.Int,
    np.longlong: jaxtyping.Int,
    np.short: jaxtyping.Int,
    np.uint8: jaxtyping.Int,
    # torch float
    torch.float: jaxtyping.Float,
    torch.float16: jaxtyping.Float,
    torch.float32: jaxtyping.Float,
    torch.float64: jaxtyping.Float,
    torch.half: jaxtyping.Float,
    torch.double: jaxtyping.Float,
    torch.bfloat16: jaxtyping.Float,
    # torch int
    torch.int: jaxtyping.Int,
    torch.int8: jaxtyping.Int,
    torch.int16: jaxtyping.Int,
    torch.int32: jaxtyping.Int,
    torch.int64: jaxtyping.Int,
    torch.long: jaxtyping.Int,
    torch.short: jaxtyping.Int,
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"

# np.float_ and np.int_ were deprecated in numpy 1.20 and removed in 2.0
# use try/except for backwards compatibility and type checker friendliness
try:
    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
except AttributeError:
    pass  # numpy 2.0+ removed these deprecated aliases
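
# A minimal usage sketch of the mapping above (illustrative only; `_FloatBatch`
# is a hypothetical name, not part of the module's API): look up the jaxtyping
# wrapper for a concrete dtype, then subscript it with an array type and a shape string.
_FloatBatch = TYPE_TO_JAX_DTYPE[np.float32][torch.Tensor, "batch dim"]
# equivalent to writing jaxtyping.Float[torch.Tensor, "batch dim"] directly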

# TODO: add proper type annotations to this signature
# TODO: maybe get rid of this altogether?
# def jaxtype_factory(
#     name: str,
#     array_type: type,
#     default_jax_dtype: type[jaxtyping.Float | jaxtyping.Int | jaxtyping.Bool] = jaxtyping.Float,
#     legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
# ) -> type:
#     """usage:
#     ```
#     ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
#     x: ATensor["dim1 dim2", np.float32]
#     ```
#     """
#     legacy_mode_ = ErrorMode.from_any(legacy_mode)

#     class _BaseArray:
#         """jaxtyping shorthand
#         (backwards compatible with older versions of muutils.tensor_utils)

#         default_jax_dtype = {default_jax_dtype}
#         array_type = {array_type}
#         """

#         def __new__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn:
#             raise TypeError("Type FArray cannot be instantiated.")

#         def __init_subclass__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn:
#             raise TypeError(f"Cannot subclass {cls.__name__}")

#         @classmethod
#         def param_info(cls, params: typing.Union[str, tuple[Any, ...]]) -> str:
#             """useful for error printing"""
#             return "\n".join(
#                 f"{k} = {v}"
#                 for k, v in {
#                     "cls.__name__": cls.__name__,
#                     "cls.__doc__": cls.__doc__,
#                     "params": params,
#                     "type(params)": type(params),
#                 }.items()
#             )

#         @typing._tp_cache  # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
#         def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type:  # type: ignore[misc]
#             # MyTensor["dim1 dim2"]
#             if isinstance(params, str):
#                 return default_jax_dtype[array_type, params]

#             elif isinstance(params, tuple):
#                 if len(params) != 2:
#                     raise Exception(
#                         f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
#                     )

#                 if isinstance(params[0], str):
#                     # MyTensor["dim1 dim2", int]
#                     return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

#                 elif isinstance(params[0], tuple):
#                     legacy_mode_.process(
#                         f"legacy type annotation was used:\n{cls.param_info(params) = }",
#                         except_cls=Exception,
#                     )
#                     # MyTensor[("dim1", "dim2"), int]
#                     shape_anot: list[str] = list()
#                     for x in params[0]:
#                         if isinstance(x, str):
#                             shape_anot.append(x)
#                         elif isinstance(x, int):
#                             shape_anot.append(str(x))
#                         elif isinstance(x, tuple):
#                             shape_anot.append("".join(str(y) for y in x))
#                         else:
#                             raise Exception(
#                                 f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
#                             )

#                     return TYPE_TO_JAX_DTYPE[params[1]][
#                         array_type, " ".join(shape_anot)
#                     ]
#             else:
#                 raise Exception(
#                     f"unexpected type for params:\n{cls.param_info(params)}"
#                 )

#     _BaseArray.__name__ = name

#     if _BaseArray.__doc__ is None:
#         _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"

#     _BaseArray.__doc__ = _BaseArray.__doc__.format(
#         default_jax_dtype=repr(default_jax_dtype),
#         array_type=repr(array_type),
#     )

#     return _BaseArray

if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
        def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type:
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        @typing._tp_cache  # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
        def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type:
            raise NotImplementedError()


# ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

# NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]

def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
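
# Quick illustration of the conversion above (a sketch; the underscore name is
# made up for the example and is not part of the module's API):
_np_as_torch: torch.dtype = numpy_to_torch_dtype(np.dtype(np.float32))
assert _np_as_torch == torch.float32
assert numpy_to_torch_dtype(torch.int64) == torch.int64  # torch dtypes pass through unchanged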

DTYPE_LIST: list[Any] = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # simplest
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.longlong,
        np.short,
        # simplest
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"

# np.float_ and np.int_ were deprecated in numpy 1.20 and removed in 2.0
try:
    DTYPE_LIST.extend([np.float_, np.int_])  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
except AttributeError:
    pass  # numpy 2.0+ removed these deprecated aliases

DTYPE_MAP: dict[str, Any] = {
    **{str(x): x for x in DTYPE_LIST},
    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
}
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict[str, torch.dtype] = {
    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
}
"mapping from string representations of types to specifically torch types"

# ensure a plain "bool" key exists, since the comprehensions above don't reliably produce one
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
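
# A few illustrative lookups (sketch only; the exact key set depends on the
# installed numpy/torch versions, but these particular keys are stable):
assert DTYPE_MAP["float64"] is np.float64
assert DTYPE_MAP["torch.float32"] is torch.float32
assert TORCH_DTYPE_MAP["float64"] == torch.float64
assert TORCH_DTYPE_MAP["bool"] == torch.bool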

TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}
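
# Hypothetical usage sketch (the tiny model below is made up for illustration):
# look an optimizer class up by its string name, then construct it as usual.
_toy_model = torch.nn.Linear(4, 2)
_toy_optim = TORCH_OPTIMIZERS_MAP["AdamW"](_toy_model.parameters(), lr=1e-3)
assert isinstance(_toy_optim, torch.optim.AdamW)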

def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)

def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)


def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
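
# Behavior sketch for the tensor padding helpers above (example values only):
_example_1d = torch.tensor([1.0, 2.0, 3.0])
assert lpad_tensor(_example_1d, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]
assert rpad_tensor(_example_1d, 5).tolist() == [1.0, 2.0, 3.0, 0.0, 0.0]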

def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)

def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)


def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)

def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})


def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent won't play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
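
# Sketch of the shape helpers on a toy state dict (key names are invented; the
# nesting comes from dotlist_to_nested_dict splitting keys on "."):
_toy_sd = {"layer.weight": torch.zeros(2, 3), "layer.bias": torch.zeros(2)}
_toy_shapes = get_dict_shapes(_toy_sd)
# expected result: {"layer": {"weight": (2, 3), "bias": (2,)}}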

class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass


class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass


class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass


class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass

def compare_state_dicts(
    d1: dict[str, Any],
    d2: dict[str, Any],
    rtol: float = 1e-5,
    atol: float = 1e-8,
    verbose: bool = True,
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
        (defaults to `1e-5`)
    - `atol : float`
        (defaults to `1e-8`)
    - `verbose : bool`
        (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set[str] = set(d1.keys())
    d2_keys: set[str] = set(d2.keys())
    symmetric_diff: set[str] = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set[str] = d1_keys - d2_keys
    keys_diff_2: set[str] = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
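
# Usage sketch for compare_state_dicts (toy dicts; in practice these would come
# from model.state_dict()): differences within tolerance pass silently, while a
# key, shape, or value mismatch raises the corresponding StateDict*Error.
_sd_a = {"w": torch.ones(2, 2)}
_sd_b = {"w": torch.ones(2, 2) + 1e-9}
compare_state_dicts(_sd_a, _sd_b)  # no exception: keys, shapes, and values all match within atol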