muutils.mlutils
miscellaneous utilities for ML pipelines
1"miscellaneous utilities for ML pipelines" 2 3from __future__ import annotations 4 5import json 6import os 7import random 8import typing 9import warnings 10from itertools import islice 11from pathlib import Path 12from typing import Any, Callable, Generator, Iterable, Optional, TypeVar, Union 13 14ARRAY_IMPORTS: bool 15try: 16 import numpy as np 17 import torch 18 import torch.backends.mps 19 20 ARRAY_IMPORTS = True 21except ImportError as e: 22 warnings.warn( 23 f"Numpy or torch not installed. Array operations will not be available.\n{e}" 24 ) 25 ARRAY_IMPORTS = False 26 27DEFAULT_SEED: int = 42 28GLOBAL_SEED: int = DEFAULT_SEED 29 30 31def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device": 32 """Get the torch.device instance on which `torch.Tensor`s should be allocated.""" 33 if not ARRAY_IMPORTS: 34 raise ImportError( 35 "Numpy or torch not installed. Array operations will not be available." 36 ) 37 try: 38 # if device is given 39 assert torch, "Torch is not available, cannot get device" # pyright: ignore[reportPossiblyUnboundVariable] 40 if device is not None: 41 device = torch.device(device) 42 if any( 43 [ 44 torch.cuda.is_available() and device.type == "cuda", 45 torch.backends.mps.is_available() and device.type == "mps", 46 device.type == "cpu", 47 ] 48 ): 49 # if device is given and available 50 pass 51 else: 52 warnings.warn( 53 f"Specified device {device} is not available, falling back to CPU" 54 ) 55 return torch.device("cpu") 56 57 # no device given, infer from availability 58 else: 59 if torch.cuda.is_available(): 60 device = torch.device("cuda") 61 elif torch.backends.mps.is_available(): 62 device = torch.device("mps") 63 else: 64 device = torch.device("cpu") 65 66 # put a dummy tensor on the device to check if it is available 67 _dummy = torch.zeros(1, device=device) 68 69 return device 70 71 except Exception as e: 72 warnings.warn( 73 f"Error while getting device, falling back to CPU. 
Error: {e}", 74 RuntimeWarning, 75 ) 76 return torch.device("cpu") # pyright: ignore[reportPossiblyUnboundVariable] 77 78 79def set_reproducibility(seed: int = DEFAULT_SEED): 80 """ 81 Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information. 82 83 Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades 84 off performance for reproducibility. Set use_deterministic_algorithms to True to improve performance. 85 """ 86 global GLOBAL_SEED 87 88 GLOBAL_SEED = seed 89 90 random.seed(seed) 91 92 if ARRAY_IMPORTS: 93 try: 94 assert np, "Numpy is not available, cannot set seed for numpy" # pyright: ignore[reportPossiblyUnboundVariable] 95 np.random.seed(seed) 96 except Exception as e: 97 warnings.warn(f"Error while setting seed for numpy: {e}", RuntimeWarning) 98 99 try: 100 assert torch, "Torch is not available, cannot set seed for torch" # pyright: ignore[reportPossiblyUnboundVariable] 101 torch.manual_seed(seed) 102 103 torch.use_deterministic_algorithms(True) 104 except Exception as e: 105 warnings.warn(f"Error while setting seed for torch: {e}", RuntimeWarning) 106 107 # Ensure reproducibility for concurrent CUDA streams 108 # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility. 
109 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 110 111 112T = TypeVar("T") 113 114 115def chunks(it: Iterable[T], chunk_size: int) -> Generator[list[T], Any, None]: 116 """Yield successive chunks from an iterator.""" 117 # https://stackoverflow.com/a/61435714 118 iterator = iter(it) 119 while chunk := list(islice(iterator, chunk_size)): 120 yield chunk 121 122 123def get_checkpoint_paths_for_run( 124 run_path: Path, 125 extension: typing.Literal["pt", "zanj"], 126 checkpoints_format: str = "checkpoints/model.iter_*.{extension}", 127) -> list[tuple[int, Path]]: 128 """get checkpoints of the format from the run_path 129 130 note that `checkpoints_format` should contain a glob pattern with: 131 - unresolved "{extension}" format term for the extension 132 - a wildcard for the iteration number 133 """ 134 135 assert run_path.is_dir(), ( 136 f"Model path {run_path} is not a directory (expect run directory, not model files)" 137 ) 138 139 return [ 140 (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path) 141 for checkpoint_path in sorted( 142 Path(run_path).glob(checkpoints_format.format(extension=extension)) 143 ) 144 ] 145 146 147F = TypeVar("F", bound=Callable[..., Any]) 148 149 150def register_method( 151 method_dict: dict[str, Callable[..., Any]], 152 custom_name: Optional[str] = None, 153) -> Callable[[F], F]: 154 """Decorator to add a method to the method_dict""" 155 156 def decorator(method: F) -> F: 157 method_name: str 158 if custom_name is None: 159 method_name_orig: str | None = getattr(method, "__name__", None) 160 if method_name_orig is None: 161 warnings.warn( 162 f"Method {method} does not have a name, using sanitized repr" 163 ) 164 from muutils.misc import sanitize_identifier 165 166 method_name = sanitize_identifier(repr(method)) 167 else: 168 method_name = method_name_orig 169 else: 170 method_name = custom_name 171 # TYPING: ty complains here 172 method.__name__ = custom_name # type: ignore[unresolved-attribute] 173 assert 
method_name not in method_dict, ( 174 f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }" 175 ) 176 method_dict[method_name] = method 177 return method 178 179 return decorator 180 181 182def pprint_summary(summary: dict): 183 print(json.dumps(summary, indent=2))
ARRAY_IMPORTS: bool =
True
DEFAULT_SEED: int =
42
GLOBAL_SEED: int =
42
def
get_device(device: Union[str, torch.device, NoneType] = None) -> torch.device:
32def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device": 33 """Get the torch.device instance on which `torch.Tensor`s should be allocated.""" 34 if not ARRAY_IMPORTS: 35 raise ImportError( 36 "Numpy or torch not installed. Array operations will not be available." 37 ) 38 try: 39 # if device is given 40 assert torch, "Torch is not available, cannot get device" # pyright: ignore[reportPossiblyUnboundVariable] 41 if device is not None: 42 device = torch.device(device) 43 if any( 44 [ 45 torch.cuda.is_available() and device.type == "cuda", 46 torch.backends.mps.is_available() and device.type == "mps", 47 device.type == "cpu", 48 ] 49 ): 50 # if device is given and available 51 pass 52 else: 53 warnings.warn( 54 f"Specified device {device} is not available, falling back to CPU" 55 ) 56 return torch.device("cpu") 57 58 # no device given, infer from availability 59 else: 60 if torch.cuda.is_available(): 61 device = torch.device("cuda") 62 elif torch.backends.mps.is_available(): 63 device = torch.device("mps") 64 else: 65 device = torch.device("cpu") 66 67 # put a dummy tensor on the device to check if it is available 68 _dummy = torch.zeros(1, device=device) 69 70 return device 71 72 except Exception as e: 73 warnings.warn( 74 f"Error while getting device, falling back to CPU. Error: {e}", 75 RuntimeWarning, 76 ) 77 return torch.device("cpu") # pyright: ignore[reportPossiblyUnboundVariable]
Get the torch.device instance on which torch.Tensors should be allocated.
def
set_reproducibility(seed: int = 42):
80def set_reproducibility(seed: int = DEFAULT_SEED): 81 """ 82 Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information. 83 84 Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades 85 off performance for reproducibility. Set use_deterministic_algorithms to True to improve performance. 86 """ 87 global GLOBAL_SEED 88 89 GLOBAL_SEED = seed 90 91 random.seed(seed) 92 93 if ARRAY_IMPORTS: 94 try: 95 assert np, "Numpy is not available, cannot set seed for numpy" # pyright: ignore[reportPossiblyUnboundVariable] 96 np.random.seed(seed) 97 except Exception as e: 98 warnings.warn(f"Error while setting seed for numpy: {e}", RuntimeWarning) 99 100 try: 101 assert torch, "Torch is not available, cannot set seed for torch" # pyright: ignore[reportPossiblyUnboundVariable] 102 torch.manual_seed(seed) 103 104 torch.use_deterministic_algorithms(True) 105 except Exception as e: 106 warnings.warn(f"Error while setting seed for torch: {e}", RuntimeWarning) 107 108 # Ensure reproducibility for concurrent CUDA streams 109 # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility. 110 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.
Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades off performance for reproducibility. Set use_deterministic_algorithms to False to improve performance.
def
chunks(it: Iterable[~T], chunk_size: int) -> Generator[list[~T], Any, NoneType]:
116def chunks(it: Iterable[T], chunk_size: int) -> Generator[list[T], Any, None]: 117 """Yield successive chunks from an iterator.""" 118 # https://stackoverflow.com/a/61435714 119 iterator = iter(it) 120 while chunk := list(islice(iterator, chunk_size)): 121 yield chunk
Yield successive chunks from an iterator.
def
get_checkpoint_paths_for_run( run_path: pathlib._local.Path, extension: Literal['pt', 'zanj'], checkpoints_format: str = 'checkpoints/model.iter_*.{extension}') -> list[tuple[int, pathlib._local.Path]]:
124def get_checkpoint_paths_for_run( 125 run_path: Path, 126 extension: typing.Literal["pt", "zanj"], 127 checkpoints_format: str = "checkpoints/model.iter_*.{extension}", 128) -> list[tuple[int, Path]]: 129 """get checkpoints of the format from the run_path 130 131 note that `checkpoints_format` should contain a glob pattern with: 132 - unresolved "{extension}" format term for the extension 133 - a wildcard for the iteration number 134 """ 135 136 assert run_path.is_dir(), ( 137 f"Model path {run_path} is not a directory (expect run directory, not model files)" 138 ) 139 140 return [ 141 (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path) 142 for checkpoint_path in sorted( 143 Path(run_path).glob(checkpoints_format.format(extension=extension)) 144 ) 145 ]
get checkpoints of the format from the run_path
note that checkpoints_format should contain a glob pattern with:
- unresolved "{extension}" format term for the extension
- a wildcard for the iteration number
def
register_method( method_dict: dict[str, typing.Callable[..., typing.Any]], custom_name: Optional[str] = None) -> Callable[[~F], ~F]:
151def register_method( 152 method_dict: dict[str, Callable[..., Any]], 153 custom_name: Optional[str] = None, 154) -> Callable[[F], F]: 155 """Decorator to add a method to the method_dict""" 156 157 def decorator(method: F) -> F: 158 method_name: str 159 if custom_name is None: 160 method_name_orig: str | None = getattr(method, "__name__", None) 161 if method_name_orig is None: 162 warnings.warn( 163 f"Method {method} does not have a name, using sanitized repr" 164 ) 165 from muutils.misc import sanitize_identifier 166 167 method_name = sanitize_identifier(repr(method)) 168 else: 169 method_name = method_name_orig 170 else: 171 method_name = custom_name 172 # TYPING: ty complains here 173 method.__name__ = custom_name # type: ignore[unresolved-attribute] 174 assert method_name not in method_dict, ( 175 f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }" 176 ) 177 method_dict[method_name] = method 178 return method 179 180 return decorator
Decorator to add a method to the method_dict
def
pprint_summary(summary: dict):