Coverage for muutils / mlutils.py: 40%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:25 -0700

1"miscellaneous utilities for ML pipelines" 

2 

3from __future__ import annotations 

4 

5import json 

6import os 

7import random 

8import typing 

9import warnings 

10from itertools import islice 

11from pathlib import Path 

12from typing import Any, Callable, Generator, Iterable, Optional, TypeVar, Union 

13 

# Flag recording whether the optional array libraries imported successfully.
# Checked by `get_device` and `set_reproducibility` before touching numpy/torch.
ARRAY_IMPORTS: bool
try:
    import numpy as np
    import torch
    # imported explicitly so `torch.backends.mps.is_available()` is usable below
    import torch.backends.mps

    ARRAY_IMPORTS = True
except ImportError as e:
    # best-effort: the rest of the module still works without numpy/torch
    warnings.warn(
        f"Numpy or torch not installed. Array operations will not be available.\n{e}"
    )
    ARRAY_IMPORTS = False

# fallback seed used when callers of `set_reproducibility` don't provide one
DEFAULT_SEED: int = 42
# last seed passed to `set_reproducibility` (module-level mutable state)
GLOBAL_SEED: int = DEFAULT_SEED

29 

30 

def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
    """Get the torch.device instance on which `torch.Tensor`s should be allocated.

    If `device` is given, verify that backend is actually available, warning and
    falling back to CPU if it is not. If `device` is `None`, pick the best
    available backend in the order: CUDA, MPS, CPU. A dummy tensor is allocated
    on the chosen device to confirm it really works; any failure falls back to
    CPU with a `RuntimeWarning`.

    # Parameters:
    - `device : Union[str, torch.device, None]`
        requested device, or `None` to auto-detect
        (defaults to `None`)

    # Returns:
    - `torch.device`
        a device that has been verified usable

    # Raises:
    - `ImportError` : if numpy/torch were not importable at module load time
    """
    if not ARRAY_IMPORTS:
        raise ImportError(
            "Numpy or torch not installed. Array operations will not be available."
        )
    try:
        assert torch, "Torch is not available, cannot get device"  # pyright: ignore[reportPossiblyUnboundVariable]
        if device is not None:
            # device explicitly requested -- check that its backend is usable
            device = torch.device(device)
            requested_available: bool = (
                (device.type == "cuda" and torch.cuda.is_available())
                or (device.type == "mps" and torch.backends.mps.is_available())
                or device.type == "cpu"
            )
            if not requested_available:
                warnings.warn(
                    f"Specified device {device} is not available, falling back to CPU"
                )
                return torch.device("cpu")
        else:
            # no device given, infer from availability (prefer cuda, then mps)
            if torch.cuda.is_available():
                device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                device = torch.device("mps")
            else:
                device = torch.device("cpu")

        # put a dummy tensor on the device to check if it is available
        _dummy = torch.zeros(1, device=device)

        return device

    except Exception as e:
        warnings.warn(
            f"Error while getting device, falling back to CPU. Error: {e}",
            RuntimeWarning,
        )
        return torch.device("cpu")  # pyright: ignore[reportPossiblyUnboundVariable]

77 

78 

def set_reproducibility(seed: int = DEFAULT_SEED):
    """Seed every available RNG to improve model reproducibility.

    See https://github.com/NVIDIA/framework-determinism for more information.

    Deterministic operations tend to have worse performance than
    nondeterministic ones, so enabling `torch.use_deterministic_algorithms`
    here trades performance for reproducibility.
    """
    global GLOBAL_SEED
    GLOBAL_SEED = seed

    # stdlib RNG is always available
    random.seed(seed)

    if not ARRAY_IMPORTS:
        return

    # numpy/torch seeding is best-effort: failures warn rather than raise
    try:
        assert np, "Numpy is not available, cannot set seed for numpy"  # pyright: ignore[reportPossiblyUnboundVariable]
        np.random.seed(seed)
    except Exception as err:
        warnings.warn(f"Error while setting seed for numpy: {err}", RuntimeWarning)

    try:
        assert torch, "Torch is not available, cannot set seed for torch"  # pyright: ignore[reportPossiblyUnboundVariable]
        torch.manual_seed(seed)
        torch.use_deterministic_algorithms(True)
    except Exception as err:
        warnings.warn(f"Error while setting seed for torch: {err}", RuntimeWarning)

    # Ensure reproducibility for concurrent CUDA streams
    # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

110 

111 

# generic element type for `chunks`
T = TypeVar("T")

114 

def chunks(it: Iterable[T], chunk_size: int) -> Generator[list[T], Any, None]:
    """Yield successive lists of up to `chunk_size` items from `it`.

    The final chunk may be shorter if the iterable is exhausted.
    """
    # https://stackoverflow.com/a/61435714
    source = iter(it)
    while True:
        batch = list(islice(source, chunk_size))
        if not batch:
            return
        yield batch

121 

122 

def get_checkpoint_paths_for_run(
    run_path: Path,
    extension: typing.Literal["pt", "zanj"],
    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
) -> list[tuple[int, Path]]:
    """get checkpoints of the format from the run_path

    note that `checkpoints_format` should contain a glob pattern with:
    - unresolved "{extension}" format term for the extension
    - a wildcard for the iteration number

    # Parameters:
    - `run_path : Path`
        directory of a single run (not a model file)
    - `extension : typing.Literal["pt", "zanj"]`
        file extension substituted into `checkpoints_format`
    - `checkpoints_format : str`
        glob pattern relative to `run_path`
        (defaults to `"checkpoints/model.iter_*.{extension}"`)

    # Returns:
    - `list[tuple[int, Path]]`
        `(iteration, path)` pairs sorted by iteration number
    """

    assert run_path.is_dir(), (
        f"Model path {run_path} is not a directory (expect run directory, not model files)"
    )

    def _iteration(checkpoint_path: Path) -> int:
        # stem looks like "model.iter_123" -- grab the trailing number
        return int(checkpoint_path.stem.split("_")[-1].split(".")[0])

    # sort numerically by iteration: sorting the paths lexicographically would
    # incorrectly order "iter_10" before "iter_2"
    return sorted(
        (_iteration(checkpoint_path), checkpoint_path)
        for checkpoint_path in Path(run_path).glob(
            checkpoints_format.format(extension=extension)
        )
    )

145 

146 

# generic callable type, so `register_method` preserves the decorated signature
F = TypeVar("F", bound=Callable[..., Any])

149 

def register_method(
    method_dict: dict[str, Callable[..., Any]],
    custom_name: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator to add a method to the method_dict

    The key is `custom_name` if given, otherwise the method's `__name__`
    (or a sanitized `repr` if it has none). Duplicate keys are rejected.
    """

    def decorator(method: F) -> F:
        # determine the key under which to register this method
        method_name: str
        if custom_name is not None:
            method_name = custom_name
            # TYPING: ty complains here
            method.__name__ = custom_name  # type: ignore[unresolved-attribute]
        else:
            name_attr: str | None = getattr(method, "__name__", None)
            if name_attr is not None:
                method_name = name_attr
            else:
                warnings.warn(
                    f"Method {method} does not have a name, using sanitized repr"
                )
                from muutils.misc import sanitize_identifier

                method_name = sanitize_identifier(repr(method))

        assert method_name not in method_dict, (
            f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"
        )
        method_dict[method_name] = method
        return method

    return decorator

180 

181 

def pprint_summary(summary: dict):
    """Pretty-print `summary` to stdout as indented JSON."""
    serialized: str = json.dumps(summary, indent=2)
    print(serialized)