Coverage for muutils/mlutils.py: 39%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"miscellaneous utilities for ML pipelines" 

2 

3from __future__ import annotations 

4 

5import json 

6import os 

7import random 

8import typing 

9import warnings 

10from itertools import islice 

11from pathlib import Path 

12from typing import Any, Callable, Optional, TypeVar, Union 

13 

# Records whether the optional array libraries (numpy and torch) imported.
ARRAY_IMPORTS: bool
try:
    import numpy as np
    import torch
except ImportError as e:
    # The module stays importable without the optional dependencies;
    # only a warning is emitted and the flag is cleared.
    warnings.warn(
        f"Numpy or torch not installed. Array operations will not be available.\n{e}"
    )
    ARRAY_IMPORTS = False
else:
    ARRAY_IMPORTS = True

# Seed used when a caller does not supply one.
DEFAULT_SEED: int = 42
# Seed currently in effect; starts out equal to DEFAULT_SEED.
GLOBAL_SEED: int = DEFAULT_SEED

28 

29 

def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
    """Pick the `torch.device` on which `torch.Tensor`s should be allocated.

    When `device` is given it is converted with `torch.device(...)` and
    accepted only if it is a CPU device, a CUDA device while CUDA is
    available, or an MPS device while MPS is available; otherwise a warning
    is issued and the CPU device is returned. When `device` is `None` the
    first available of CUDA, MPS, CPU is chosen.

    Before returning, a one-element tensor is allocated on the chosen device
    as a final check. Any exception raised along the way is reported as a
    `RuntimeWarning` and the CPU device is returned instead.

    Raises `ImportError` if numpy or torch could not be imported.
    """
    if not ARRAY_IMPORTS:
        raise ImportError(
            "Numpy or torch not installed. Array operations will not be available."
        )

    try:
        if device is None:
            # nothing requested: prefer CUDA, then MPS, then CPU
            if torch.cuda.is_available():
                chosen = torch.device("cuda")
            elif torch.backends.mps.is_available():
                chosen = torch.device("mps")
            else:
                chosen = torch.device("cpu")
        else:
            chosen = torch.device(device)
            # both backend probes always run, whichever device was requested
            cuda_ok = torch.cuda.is_available()
            mps_ok = torch.backends.mps.is_available()
            supported = (
                (cuda_ok and chosen.type == "cuda")
                or (mps_ok and chosen.type == "mps")
                or chosen.type == "cpu"
            )
            if not supported:
                warnings.warn(
                    f"Specified device {chosen} is not available, falling back to CPU"
                )
                return torch.device("cpu")

        # allocate a throwaway tensor to confirm the device really works
        torch.zeros(1, device=chosen)
        return chosen

    except Exception as e:
        # best effort: any failure above degrades to the CPU device
        warnings.warn(
            f"Error while getting device, falling back to CPU. Error: {e}",
            RuntimeWarning,
        )
        return torch.device("cpu")

75 

76 

def set_reproducibility(seed: int = DEFAULT_SEED):
    """
    Seed the random number generators and request deterministic torch algorithms.

    The seed is stored in the module-level `GLOBAL_SEED` and passed to
    `random.seed`. If numpy and torch were imported, it is also passed to
    `np.random.seed` and `torch.manual_seed`, `torch.use_deterministic_algorithms(True)`
    is called, and the `CUBLAS_WORKSPACE_CONFIG` environment variable is set.

    Deterministic operations tend to have worse performance than nondeterministic
    operations, so this function trades off performance for reproducibility.
    See https://github.com/NVIDIA/framework-determinism for more information.
    """
    global GLOBAL_SEED
    GLOBAL_SEED = seed
    random.seed(seed)

    # without the array libraries only the stdlib generator can be seeded
    if not ARRAY_IMPORTS:
        return

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)

    # Ensure reproducibility for concurrent CUDA streams
    # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

98 

99 

def chunks(it, chunk_size):
    """Yield lists of up to `chunk_size` consecutive items taken from `it`.

    Every list is full except possibly the last one; nothing is yielded
    once the underlying iterator is exhausted.
    """
    # https://stackoverflow.com/a/61435714
    source = iter(it)
    while True:
        batch = list(islice(source, chunk_size))
        if not batch:
            # an empty slice means the iterator has run dry
            return
        yield batch

106 

107 

def get_checkpoint_paths_for_run(
    run_path: Union[str, Path],
    extension: typing.Literal["pt", "zanj"],
    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
) -> list[tuple[int, Path]]:
    """get checkpoints of the format from the run_path

    note that `checkpoints_format` should contain a glob pattern with:
    - unresolved "{extension}" format term for the extension
    - a wildcard for the iteration number

    The iteration number is read from the part of the file stem after the
    last underscore, up to the first following dot.

    # Parameters:
    - `run_path`: run directory to search; a string is converted to a `Path`
    - `extension`: substituted for "{extension}" in `checkpoints_format`
    - `checkpoints_format`: glob pattern, relative to `run_path`

    # Returns:
    list of `(iteration, checkpoint_path)` tuples in ascending order of
    iteration number (so `iter_2` comes before `iter_10`)

    # Raises:
    - `AssertionError` if `run_path` is not a directory
    - `ValueError` if the iteration part of a matched file name is not an integer
    """
    # coerce before the directory check so plain string paths work too
    run_path = Path(run_path)

    assert run_path.is_dir(), f"Model path {run_path} is not a directory (expect run directory, not model files)"

    pattern: str = checkpoints_format.format(extension=extension)
    found: list[tuple[int, Path]] = [
        (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path)
        for checkpoint_path in run_path.glob(pattern)
    ]
    # tuples compare on the integer iteration first, so this is numeric order
    # rather than the lexicographic order of the paths
    return sorted(found)

128 

129 

# Type variable that lets the decorator hand back exactly the callable type it received.
F = TypeVar("F", bound=Callable[..., Any])


def register_method(
    method_dict: dict[str, Callable[..., Any]],
    custom_name: Optional[str] = None,
) -> Callable[[F], F]:
    """Build a decorator that stores the decorated callable in `method_dict`.

    The key is `custom_name` when given (the callable's `__name__` is then
    overwritten with it), otherwise the callable's `__name__`. A callable
    without a `__name__` triggers a warning and is stored under a sanitized
    version of its `repr`. The key must not already be present in
    `method_dict`. The decorator returns the callable it was given.
    """

    def decorator(method: F) -> F:
        method_name: str
        if custom_name is not None:
            # explicit name: use it as the key and rename the callable too
            method_name = custom_name
            method.__name__ = custom_name
        else:
            inferred: Optional[str] = getattr(method, "__name__", None)
            if inferred is not None:
                method_name = inferred
            else:
                warnings.warn(
                    f"Method {method} does not have a name, using sanitized repr"
                )
                from muutils.misc import sanitize_identifier

                method_name = sanitize_identifier(repr(method))

        assert (
            method_name not in method_dict
        ), f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"

        method_dict[method_name] = method
        return method

    return decorator

162 

163 

def pprint_summary(summary: dict):
    """Print `summary` to stdout as JSON indented by two spaces."""
    rendered = json.dumps(summary, indent=2)
    print(rendered)