Coverage for muutils/mlutils.py: 40%
83 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
1"miscellaneous utilities for ML pipelines"
3from __future__ import annotations
5import json
6import os
7import random
8import typing
9import warnings
10from itertools import islice
11from pathlib import Path
12from typing import Any, Callable, Generator, Iterable, Optional, TypeVar, Union
# flag recording whether the optional array libraries (numpy, torch) imported successfully
ARRAY_IMPORTS: bool
try:
    import numpy as np
    import torch
    import torch.backends.mps

    ARRAY_IMPORTS = True
except ImportError as e:
    # degrade gracefully: warn once at import time; array-dependent functions
    # below check the ARRAY_IMPORTS flag before touching numpy/torch
    warnings.warn(
        f"Numpy or torch not installed. Array operations will not be available.\n{e}"
    )
    ARRAY_IMPORTS = False
# default RNG seed used when callers of `set_reproducibility` don't supply one
DEFAULT_SEED: int = 42
# last seed passed to `set_reproducibility` (starts at the default)
GLOBAL_SEED: int = DEFAULT_SEED
def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
    """Get the torch.device instance on which `torch.Tensor`s should be allocated."""
    if not ARRAY_IMPORTS:
        raise ImportError(
            "Numpy or torch not installed. Array operations will not be available."
        )
    try:
        assert torch, "Torch is not available, cannot get device"  # pyright: ignore[reportPossiblyUnboundVariable]
        if device is not None:
            # an explicit device was requested -- verify its backend is usable
            device = torch.device(device)
            requested_available: bool = (
                (device.type == "cuda" and torch.cuda.is_available())
                or (device.type == "mps" and torch.backends.mps.is_available())
                or device.type == "cpu"
            )
            if not requested_available:
                warnings.warn(
                    f"Specified device {device} is not available, falling back to CPU"
                )
                return torch.device("cpu")
        else:
            # no device requested -- pick the best available backend
            if torch.cuda.is_available():
                device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                device = torch.device("mps")
            else:
                device = torch.device("cpu")

        # allocate a throwaway tensor to confirm the device actually works
        _probe = torch.zeros(1, device=device)

        return device
    except Exception as e:
        # any failure above (bad device string, broken backend, failed alloc)
        # degrades to CPU with a warning instead of raising
        warnings.warn(
            f"Error while getting device, falling back to CPU. Error: {e}",
            RuntimeWarning,
        )
        return torch.device("cpu")  # pyright: ignore[reportPossiblyUnboundVariable]
def set_reproducibility(seed: int = DEFAULT_SEED):
    """
    Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

    Deterministic operations tend to have worse performance than nondeterministic operations, so this method
    trades off performance for reproducibility: it calls `torch.use_deterministic_algorithms(True)`.
    (Pass `False` to that torch function elsewhere if you want performance over determinism.)

    # Parameters:
    - `seed : int` seed applied to `random`, and (when available) `numpy` and `torch`
      (defaults to `DEFAULT_SEED`)
    """
    global GLOBAL_SEED

    # record the seed so other code can see what was last set
    GLOBAL_SEED = seed

    random.seed(seed)

    if ARRAY_IMPORTS:
        # numpy/torch seeding is best-effort: failures become warnings, not errors
        try:
            assert np, "Numpy is not available, cannot set seed for numpy"  # pyright: ignore[reportPossiblyUnboundVariable]
            np.random.seed(seed)
        except Exception as e:
            warnings.warn(f"Error while setting seed for numpy: {e}", RuntimeWarning)

        try:
            assert torch, "Torch is not available, cannot set seed for torch"  # pyright: ignore[reportPossiblyUnboundVariable]
            torch.manual_seed(seed)

            # prefer deterministic kernels even when slower
            torch.use_deterministic_algorithms(True)
        except Exception as e:
            warnings.warn(f"Error while setting seed for torch: {e}", RuntimeWarning)

        # Ensure reproducibility for concurrent CUDA streams
        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# generic element type, used by `chunks` below
T = TypeVar("T")
def chunks(it: Iterable[T], chunk_size: int) -> Generator[list[T], Any, None]:
    """Yield successive chunks from an iterator."""
    # https://stackoverflow.com/a/61435714
    source = iter(it)
    while True:
        # islice consumes at most chunk_size items; an empty batch means exhaustion
        batch = list(islice(source, chunk_size))
        if not batch:
            return
        yield batch
def get_checkpoint_paths_for_run(
    run_path: Path,
    extension: typing.Literal["pt", "zanj"],
    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
) -> list[tuple[int, Path]]:
    """get checkpoints of the format from the run_path

    note that `checkpoints_format` should contain a glob pattern with:
    - unresolved "{extension}" format term for the extension
    - a wildcard for the iteration number

    # Parameters:
    - `run_path : Path` run directory to search (not a model file)
    - `extension : typing.Literal["pt", "zanj"]` checkpoint file extension
    - `checkpoints_format : str` glob pattern with an `{extension}` placeholder

    # Returns:
    - `list[tuple[int, Path]]` `(iteration, path)` pairs, sorted numerically by iteration

    # Raises:
    - `AssertionError` if `run_path` is not a directory
    """

    assert run_path.is_dir(), (
        f"Model path {run_path} is not a directory (expect run directory, not model files)"
    )

    def _iteration(checkpoint_path: Path) -> int:
        # stem looks like "model.iter_123" -- take the number after the last "_"
        return int(checkpoint_path.stem.split("_")[-1].split(".")[0])

    # sort numerically by iteration: a plain lexicographic path sort would
    # order "iter_10" before "iter_2"
    return sorted(
        (
            (_iteration(checkpoint_path), checkpoint_path)
            for checkpoint_path in Path(run_path).glob(
                checkpoints_format.format(extension=extension)
            )
        ),
        key=lambda pair: pair[0],
    )
# callable type preserved through the `register_method` decorator below
F = TypeVar("F", bound=Callable[..., Any])
def register_method(
    method_dict: dict[str, Callable[..., Any]],
    custom_name: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator to add a method to the method_dict

    The registry key is `custom_name` when given (the callable's `__name__` is
    overwritten to match); otherwise the callable's own `__name__`, falling back
    to a sanitized `repr` when it has no name. Registering a duplicate key fails
    with `AssertionError`. The callable is returned unchanged."""

    def decorator(method: F) -> F:
        # NOTE: keep the local named `method_name` -- it appears verbatim in the
        # self-documenting f-string of the assertion message below
        method_name: str
        if custom_name is not None:
            method_name = custom_name
            # keep the callable's reported name in sync with its registry key
            # TYPING: ty complains here
            method.__name__ = custom_name  # type: ignore[unresolved-attribute]
        else:
            name_attr: str | None = getattr(method, "__name__", None)
            if name_attr is not None:
                method_name = name_attr
            else:
                warnings.warn(
                    f"Method {method} does not have a name, using sanitized repr"
                )
                from muutils.misc import sanitize_identifier

                method_name = sanitize_identifier(repr(method))
        assert method_name not in method_dict, (
            f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"
        )
        method_dict[method_name] = method
        return method

    return decorator
def pprint_summary(summary: dict):
    """Pretty-print `summary` to stdout as 2-space-indented JSON."""
    rendered = json.dumps(summary, indent=2)
    print(rendered)