docs for muutils v0.9.0
View Source on GitHub

muutils.mlutils

miscellaneous utilities for ML pipelines


  1"miscellaneous utilities for ML pipelines"
  2
  3from __future__ import annotations
  4
  5import json
  6import os
  7import random
  8import typing
  9import warnings
 10from itertools import islice
 11from pathlib import Path
 12from typing import Any, Callable, Generator, Iterable, Optional, TypeVar, Union
 13
# whether numpy and torch were importable at module load time;
# array-dependent helpers check this flag instead of importing at call time
ARRAY_IMPORTS: bool
try:
    import numpy as np
    import torch
    import torch.backends.mps

    ARRAY_IMPORTS = True
except ImportError as e:
    # best-effort: warn rather than crash, so the non-array utilities
    # in this module remain usable without numpy/torch installed
    warnings.warn(
        f"Numpy or torch not installed. Array operations will not be available.\n{e}"
    )
    ARRAY_IMPORTS = False

# default seed used by `set_reproducibility`
DEFAULT_SEED: int = 42
# last seed passed to `set_reproducibility` (that function updates this)
GLOBAL_SEED: int = DEFAULT_SEED
 29
 30
 31def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
 32    """Get the torch.device instance on which `torch.Tensor`s should be allocated."""
 33    if not ARRAY_IMPORTS:
 34        raise ImportError(
 35            "Numpy or torch not installed. Array operations will not be available."
 36        )
 37    try:
 38        # if device is given
 39        assert torch, "Torch is not available, cannot get device"  # pyright: ignore[reportPossiblyUnboundVariable]
 40        if device is not None:
 41            device = torch.device(device)
 42            if any(
 43                [
 44                    torch.cuda.is_available() and device.type == "cuda",
 45                    torch.backends.mps.is_available() and device.type == "mps",
 46                    device.type == "cpu",
 47                ]
 48            ):
 49                # if device is given and available
 50                pass
 51            else:
 52                warnings.warn(
 53                    f"Specified device {device} is not available, falling back to CPU"
 54                )
 55                return torch.device("cpu")
 56
 57        # no device given, infer from availability
 58        else:
 59            if torch.cuda.is_available():
 60                device = torch.device("cuda")
 61            elif torch.backends.mps.is_available():
 62                device = torch.device("mps")
 63            else:
 64                device = torch.device("cpu")
 65
 66        # put a dummy tensor on the device to check if it is available
 67        _dummy = torch.zeros(1, device=device)
 68
 69        return device
 70
 71    except Exception as e:
 72        warnings.warn(
 73            f"Error while getting device, falling back to CPU. Error: {e}",
 74            RuntimeWarning,
 75        )
 76        return torch.device("cpu")  # pyright: ignore[reportPossiblyUnboundVariable]
 77
 78
 79def set_reproducibility(seed: int = DEFAULT_SEED):
 80    """
 81    Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.
 82
 83    Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades
 84    off performance for reproducibility. Set use_deterministic_algorithms to True to improve performance.
 85    """
 86    global GLOBAL_SEED
 87
 88    GLOBAL_SEED = seed
 89
 90    random.seed(seed)
 91
 92    if ARRAY_IMPORTS:
 93        try:
 94            assert np, "Numpy is not available, cannot set seed for numpy"  # pyright: ignore[reportPossiblyUnboundVariable]
 95            np.random.seed(seed)
 96        except Exception as e:
 97            warnings.warn(f"Error while setting seed for numpy: {e}", RuntimeWarning)
 98
 99        try:
100            assert torch, "Torch is not available, cannot set seed for torch"  # pyright: ignore[reportPossiblyUnboundVariable]
101            torch.manual_seed(seed)
102
103            torch.use_deterministic_algorithms(True)
104        except Exception as e:
105            warnings.warn(f"Error while setting seed for torch: {e}", RuntimeWarning)
106
107        # Ensure reproducibility for concurrent CUDA streams
108        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
109        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
110
111
T = TypeVar("T")


def chunks(it: Iterable[T], chunk_size: int) -> Generator[list[T], Any, None]:
    """Yield successive chunks from an iterator.

    Each yielded list holds at most `chunk_size` items; the final chunk may
    be shorter. (adapted from https://stackoverflow.com/a/61435714)
    """
    source = iter(it)
    while True:
        batch: list[T] = list(islice(source, chunk_size))
        if not batch:
            return
        yield batch
121
122
def get_checkpoint_paths_for_run(
    run_path: Path,
    extension: typing.Literal["pt", "zanj"],
    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
) -> list[tuple[int, Path]]:
    """get checkpoints of the format from the run_path

    note that `checkpoints_format` should contain a glob pattern with:
     - unresolved "{extension}" format term for the extension
     - a wildcard for the iteration number

    returns a list of `(iteration, path)` tuples sorted numerically by
    iteration (a plain lexicographic sort on paths would order e.g.
    `iter_10` before `iter_9`)
    """

    assert run_path.is_dir(), (
        f"Model path {run_path} is not a directory (expect run directory, not model files)"
    )

    checkpoints: list[tuple[int, Path]] = [
        # iteration number is the last "_"-separated token of the stem,
        # e.g. "model.iter_123" -> 123
        (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path)
        for checkpoint_path in Path(run_path).glob(
            checkpoints_format.format(extension=extension)
        )
    ]
    # sort numerically by iteration, not lexicographically by path string
    checkpoints.sort(key=lambda pair: pair[0])
    return checkpoints
145
146
F = TypeVar("F", bound=Callable[..., Any])


def register_method(
    method_dict: dict[str, Callable[..., Any]],
    custom_name: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator to add a method to the method_dict"""

    def decorator(method: F) -> F:
        method_name: str
        if custom_name is not None:
            # an explicit name was given: use it and rename the function itself
            method_name = custom_name
            # TYPING: ty complains here
            method.__name__ = custom_name  # type: ignore[unresolved-attribute]
        else:
            found_name: str | None = getattr(method, "__name__", None)
            if found_name is not None:
                method_name = found_name
            else:
                # anonymous callable: fall back to a sanitized repr as the key
                warnings.warn(
                    f"Method {method} does not have a name, using sanitized repr"
                )
                from muutils.misc import sanitize_identifier

                method_name = sanitize_identifier(repr(method))
        assert method_name not in method_dict, (
            f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"
        )
        method_dict[method_name] = method
        return method

    return decorator
180
181
def pprint_summary(summary: dict):
    """Print `summary` to stdout as human-readable, 2-space-indented JSON."""
    rendered: str = json.dumps(summary, indent=2)
    print(rendered)

ARRAY_IMPORTS: bool = True
DEFAULT_SEED: int = 42
GLOBAL_SEED: int = 42
def get_device(device: Union[str, torch.device, NoneType] = None) -> torch.device:
32def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
33    """Get the torch.device instance on which `torch.Tensor`s should be allocated."""
34    if not ARRAY_IMPORTS:
35        raise ImportError(
36            "Numpy or torch not installed. Array operations will not be available."
37        )
38    try:
39        # if device is given
40        assert torch, "Torch is not available, cannot get device"  # pyright: ignore[reportPossiblyUnboundVariable]
41        if device is not None:
42            device = torch.device(device)
43            if any(
44                [
45                    torch.cuda.is_available() and device.type == "cuda",
46                    torch.backends.mps.is_available() and device.type == "mps",
47                    device.type == "cpu",
48                ]
49            ):
50                # if device is given and available
51                pass
52            else:
53                warnings.warn(
54                    f"Specified device {device} is not available, falling back to CPU"
55                )
56                return torch.device("cpu")
57
58        # no device given, infer from availability
59        else:
60            if torch.cuda.is_available():
61                device = torch.device("cuda")
62            elif torch.backends.mps.is_available():
63                device = torch.device("mps")
64            else:
65                device = torch.device("cpu")
66
67        # put a dummy tensor on the device to check if it is available
68        _dummy = torch.zeros(1, device=device)
69
70        return device
71
72    except Exception as e:
73        warnings.warn(
74            f"Error while getting device, falling back to CPU. Error: {e}",
75            RuntimeWarning,
76        )
77        return torch.device("cpu")  # pyright: ignore[reportPossiblyUnboundVariable]

Get the torch.device instance on which torch.Tensors should be allocated.

def set_reproducibility(seed: int = 42):
 80def set_reproducibility(seed: int = DEFAULT_SEED):
 81    """
 82    Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.
 83
 84    Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades
 85    off performance for reproducibility. Set use_deterministic_algorithms to True to improve performance.
 86    """
 87    global GLOBAL_SEED
 88
 89    GLOBAL_SEED = seed
 90
 91    random.seed(seed)
 92
 93    if ARRAY_IMPORTS:
 94        try:
 95            assert np, "Numpy is not available, cannot set seed for numpy"  # pyright: ignore[reportPossiblyUnboundVariable]
 96            np.random.seed(seed)
 97        except Exception as e:
 98            warnings.warn(f"Error while setting seed for numpy: {e}", RuntimeWarning)
 99
100        try:
101            assert torch, "Torch is not available, cannot set seed for torch"  # pyright: ignore[reportPossiblyUnboundVariable]
102            torch.manual_seed(seed)
103
104            torch.use_deterministic_algorithms(True)
105        except Exception as e:
106            warnings.warn(f"Error while setting seed for torch: {e}", RuntimeWarning)
107
108        # Ensure reproducibility for concurrent CUDA streams
109        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
110        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades off performance for reproducibility: it seeds random, numpy, and torch, and enables torch's deterministic algorithms via use_deterministic_algorithms(True).

def chunks(it: Iterable[~T], chunk_size: int) -> Generator[list[~T], Any, NoneType]:
116def chunks(it: Iterable[T], chunk_size: int) -> Generator[list[T], Any, None]:
117    """Yield successive chunks from an iterator."""
118    # https://stackoverflow.com/a/61435714
119    iterator = iter(it)
120    while chunk := list(islice(iterator, chunk_size)):
121        yield chunk

Yield successive chunks from an iterator.

def get_checkpoint_paths_for_run( run_path: pathlib.Path, extension: Literal['pt', 'zanj'], checkpoints_format: str = 'checkpoints/model.iter_*.{extension}') -> list[tuple[int, pathlib.Path]]:
124def get_checkpoint_paths_for_run(
125    run_path: Path,
126    extension: typing.Literal["pt", "zanj"],
127    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
128) -> list[tuple[int, Path]]:
129    """get checkpoints of the format from the run_path
130
131    note that `checkpoints_format` should contain a glob pattern with:
132     - unresolved "{extension}" format term for the extension
133     - a wildcard for the iteration number
134    """
135
136    assert run_path.is_dir(), (
137        f"Model path {run_path} is not a directory (expect run directory, not model files)"
138    )
139
140    return [
141        (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path)
142        for checkpoint_path in sorted(
143            Path(run_path).glob(checkpoints_format.format(extension=extension))
144        )
145    ]

get checkpoints of the format from the run_path

note that checkpoints_format should contain a glob pattern with:

  • unresolved "{extension}" format term for the extension
  • a wildcard for the iteration number
def register_method( method_dict: dict[str, typing.Callable[..., typing.Any]], custom_name: Optional[str] = None) -> Callable[[~F], ~F]:
151def register_method(
152    method_dict: dict[str, Callable[..., Any]],
153    custom_name: Optional[str] = None,
154) -> Callable[[F], F]:
155    """Decorator to add a method to the method_dict"""
156
157    def decorator(method: F) -> F:
158        method_name: str
159        if custom_name is None:
160            method_name_orig: str | None = getattr(method, "__name__", None)
161            if method_name_orig is None:
162                warnings.warn(
163                    f"Method {method} does not have a name, using sanitized repr"
164                )
165                from muutils.misc import sanitize_identifier
166
167                method_name = sanitize_identifier(repr(method))
168            else:
169                method_name = method_name_orig
170        else:
171            method_name = custom_name
172            # TYPING: ty complains here
173            method.__name__ = custom_name  # type: ignore[unresolved-attribute]
174        assert method_name not in method_dict, (
175            f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"
176        )
177        method_dict[method_name] = method
178        return method
179
180    return decorator

Decorator to add a method to the method_dict

def pprint_summary(summary: dict):
183def pprint_summary(summary: dict):
184    print(json.dumps(summary, indent=2))