Coverage for muutils/mlutils.py: 39%
72 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"miscellaneous utilities for ML pipelines"
3from __future__ import annotations
5import json
6import os
7import random
8import typing
9import warnings
10from itertools import islice
11from pathlib import Path
12from typing import Any, Callable, Optional, TypeVar, Union
# True only when both numpy and torch were imported successfully; the helpers
# in this module that need either library check this flag first.
ARRAY_IMPORTS: bool
try:
    import numpy as np
    import torch
except ImportError as e:
    warnings.warn(
        f"Numpy or torch not installed. Array operations will not be available.\n{e}"
    )
    ARRAY_IMPORTS = False
else:
    ARRAY_IMPORTS = True

# seed used by `set_reproducibility` when no seed is passed
DEFAULT_SEED: int = 42
# most recent seed given to `set_reproducibility` (starts out as `DEFAULT_SEED`)
GLOBAL_SEED: int = DEFAULT_SEED
def _device_type_available(device_type: str) -> bool:
    """Return whether torch reports the backend for `device_type` as usable.

    Only `"cpu"`, `"cuda"` and `"mps"` are recognized; any other device type is
    reported as unavailable. Backends are queried lazily, so asking about
    `"cpu"` never touches CUDA or MPS.
    """
    if device_type == "cpu":
        return True
    if device_type == "cuda":
        return torch.cuda.is_available()
    if device_type == "mps":
        # older torch builds have no `torch.backends.mps`; treat that as
        # "not available" rather than raising AttributeError
        mps_backend = getattr(torch.backends, "mps", None)
        return mps_backend is not None and mps_backend.is_available()
    return False


def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
    """Get the torch.device instance on which `torch.Tensor`s should be allocated.

    # Parameters:
     - `device : Union[str, torch.device, None]`
        the requested device, e.g. `"cpu"`, `"cuda"`, `"cuda:1"`, `"mps"`, or a
        `torch.device`. If `None`, the first available of cuda, mps, cpu is used.
        (defaults to `None`)

    # Returns:
     - `torch.device`
        the requested (or inferred) device if it is usable, otherwise
        `torch.device("cpu")`. A requested backend that is unavailable produces
        a `UserWarning`; any other error while resolving or probing the device
        produces a `RuntimeWarning`. Both cases fall back to CPU.

    # Raises:
     - `ImportError` : if numpy or torch could not be imported
    """
    if not ARRAY_IMPORTS:
        raise ImportError(
            "Numpy or torch not installed. Array operations will not be available."
        )
    try:
        if device is None:
            # no device given: infer from availability, preferring cuda, then mps
            if _device_type_available("cuda"):
                device = torch.device("cuda")
            elif _device_type_available("mps"):
                device = torch.device("mps")
            else:
                device = torch.device("cpu")
        else:
            device = torch.device(device)
            if not _device_type_available(device.type):
                warnings.warn(
                    f"Specified device {device} is not available, falling back to CPU"
                )
                return torch.device("cpu")

        # allocate a tiny tensor to confirm the device actually works; the
        # backend check above only looks at the device *type*, not e.g. its index
        torch.zeros(1, device=device)
        return device

    except Exception as e:
        # deliberate best-effort: any failure degrades to CPU with a warning
        warnings.warn(
            f"Error while getting device, falling back to CPU. Error: {e}",
            RuntimeWarning,
        )
        return torch.device("cpu")
def set_reproducibility(seed: int = DEFAULT_SEED) -> None:
    """
    Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

    Stores `seed` in the module-level `GLOBAL_SEED` and seeds python's `random` module. When numpy
    and torch are importable, it also seeds numpy's global RNG and torch's RNG, sets
    `CUBLAS_WORKSPACE_CONFIG` and enables `torch.use_deterministic_algorithms(True)`.

    Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades
    off performance for reproducibility. Call `torch.use_deterministic_algorithms(False)` afterwards to regain
    performance at the cost of reproducibility.

    # Parameters:
     - `seed : int`
        seed applied to every RNG listed above (defaults to `DEFAULT_SEED`)
    """
    global GLOBAL_SEED

    GLOBAL_SEED = seed

    random.seed(seed)

    if ARRAY_IMPORTS:
        # Ensure reproducibility for concurrent CUDA streams. cuBLAS only honors this
        # if it is set before cuBLAS is first used, so set it before anything else here.
        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

        np.random.seed(seed)
        torch.manual_seed(seed)

        torch.use_deterministic_algorithms(True)
def chunks(it, chunk_size):
    """Yield consecutive lists of up to `chunk_size` items taken from `it`.

    Works with any iterable, including one-shot iterators and generators. Every
    yielded list is full except possibly the last. A `chunk_size` of 0 yields
    nothing.
    """
    # approach from https://stackoverflow.com/a/61435714
    source = iter(it)
    while True:
        batch = list(islice(source, chunk_size))
        if not batch:
            return
        yield batch
def get_checkpoint_paths_for_run(
    run_path: Path,
    extension: typing.Literal["pt", "zanj"],
    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
) -> list[tuple[int, Path]]:
    """get checkpoints of the format from the run_path, ordered by iteration number

    note that `checkpoints_format` should contain a glob pattern with:
     - unresolved "{extension}" format term for the extension
     - a wildcard for the iteration number

    the iteration number is parsed from each file's stem as the text after the last
    `_` (and before any following `.`), e.g. `model.iter_1200.pt` -> `1200`

    # Parameters:
     - `run_path : Path`
        run directory to search (a `str` path is also accepted)
     - `extension : Literal["pt", "zanj"]`
        checkpoint file extension, substituted into `checkpoints_format`
     - `checkpoints_format : str`
        glob pattern relative to `run_path`
        (defaults to `"checkpoints/model.iter_*.{extension}"`)

    # Returns:
     - `list[tuple[int, Path]]`
        `(iteration, checkpoint_path)` pairs in ascending *numeric* iteration order
        (so `iter_2` comes before `iter_10`); ties are broken by path

    # Raises:
     - `AssertionError` : if `run_path` is not a directory
     - `ValueError` : if a matched file's iteration part is not an integer
    """
    # convert first so plain string paths work with `.is_dir()` below
    run_path = Path(run_path)
    assert run_path.is_dir(), f"Model path {run_path} is not a directory (expect run directory, not model files)"

    pattern: str = checkpoints_format.format(extension=extension)
    # sort on the parsed integer, not the path string: lexicographic order would
    # put "model.iter_10" before "model.iter_2"
    return sorted(
        (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path)
        for checkpoint_path in run_path.glob(pattern)
    )
# type of the callable being registered, so the decorator preserves its signature
F = TypeVar("F", bound=Callable[..., Any])


def register_method(
    method_dict: dict[str, Callable[..., Any]],
    custom_name: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator to add a method to the method_dict

    # Parameters:
     - `method_dict : dict[str, Callable[..., Any]]`
        registry the decorated callable is inserted into (mutated in place)
     - `custom_name : Optional[str]`
        key to register under. If given, it is also assigned to the callable's
        `__name__`. If `None`, the callable's `__name__` is used, or a sanitized
        `repr` if it has none.
        (defaults to `None`)

    # Returns:
     - `Callable[[F], F]`
        decorator that registers its argument and returns it unchanged
        (apart from the optional `__name__` rename)

    # Raises:
     - `AssertionError` : (at decoration time) if the name is already in `method_dict`
    """

    def decorator(method: F) -> F:
        method_name: str
        if custom_name is not None:
            method_name = custom_name
        else:
            method_name_orig: str | None = getattr(method, "__name__", None)
            if method_name_orig is None:
                warnings.warn(
                    f"Method {method} does not have a name, using sanitized repr"
                )
                from muutils.misc import sanitize_identifier

                method_name = sanitize_identifier(repr(method))
            else:
                method_name = method_name_orig

        # check for a clash *before* renaming, so a rejected registration leaves
        # `method` untouched
        assert (
            method_name not in method_dict
        ), f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"

        if custom_name is not None:
            method.__name__ = custom_name
        method_dict[method_name] = method
        return method

    return decorator
def pprint_summary(summary: dict):
    """Print `summary` to stdout as JSON indented by two spaces."""
    rendered = json.dumps(summary, indent=2)
    print(rendered)