# Stats

- 69 files
- 14082 (14K) lines
- 493606 (494K) chars
- 52418 (52K) `whitespace-split` tokens

# File Tree

```
muutils
├── .claude
├── .github
│   └── workflows
│       ├── checks.yml [117L 3,019C 326T]
│       └── make-docs.yml [ 48L 1,261C 140T]
├── muutils
│   ├── cli
│   │   ├── arg_bool.py [275L 8,997C 983T]
│   │   └── command.py [ 94L 2,842C 277T]
│   ├── json_serialize
│   │   ├── __init__.py [ 51L 2,416C 297T]
│   │   ├── array.py [353L 12,618C 1,215T]
│   │   ├── dataclass_transform_mock.py [ 29L 841C 73T]
│   │   ├── json_serialize.py [403L 15,218C 1,485T]
│   │   ├── serializable_dataclass.py [938L 37,412C 3,662T]
│   │   ├── serializable_field.py [309L 12,239C 1,252T]
│   │   ├── types.py [ 45L 1,028C 103T]
│   │   └── util.py [311L 11,310C 1,154T]
│   ├── logger
│   │   ├── __init__.py [ 30L 684C 55T]
│   │   ├── exception_context.py [ 58L 1,597C 157T]
│   │   ├── headerfuncs.py [ 68L 1,705C 209T]
│   │   ├── log_util.py [ 86L 2,306C 287T]
│   │   ├── logger.py [310L 10,890C 1,111T]
│   │   ├── loggingstream.py [102L 3,995C 429T]
│   │   ├── simplelogger.py [ 84L 2,333C 253T]
│   │   └── timing.py [ 93L 2,756C 272T]
│   ├── math
│   │   ├── __init__.py [ 4L 47C 6T]
│   │   ├── bins.py [ 73L 2,616C 192T]
│   │   └── matrix_powers.py [164L 5,340C 653T]
│   ├── misc
│   │   ├── __init__.py [ 85L 1,978C 156T]
│   │   ├── b64_decode.py [ 9L 319C 33T]
│   │   ├── classes.py [101L 3,718C 433T]
│   │   ├── freezing.py [121L 4,055C 422T]
│   │   ├── func.py [285L 10,254C 1,071T]
│   │   ├── hashing.py [ 39L 1,146C 132T]
│   │   ├── numerical.py [165L 4,728C 531T]
│   │   ├── sequence.py [234L 7,251C 877T]
│   │   ├── string.py [132L 3,635C 404T]
│   │   └── typing_breakdown.py [393L 14,834C 1,310T]
│   ├── ml
│   │   └── cuda_mem_info.py [ 55L 1,765C 198T]
│   ├── nbutils
│   │   ├── __init__.py [ 21L 488C 51T]
│   │   ├── configure_notebook.py [320L 10,401C 1,073T]
│   │   ├── convert_ipynb_to_script.py [374L 13,400C 1,219T]
│   │   ├── mermaid.py [ 20L 634C 58T]
│   │   ├── print_tex.py [ 23L 771C 92T]
│   │   └── run_notebook_tests.py [255L 9,334C 829T]
│   ├── web
│   │   ├── __init__.py [ 3L 33C 5T]
│   │   ├── bundle_html.py [388L 13,088C 1,271T]
│   │   └── html_to_pdf.py [ 38L 1,025C 105T]
│   ├── __init__.py [ 34L 544C 43T]
│   ├── collect_warnings.py [132L 4,080C 378T]
│   ├── console_unicode.py [ 34L 1,074C 135T]
│   ├── dbg.py [543L 16,533C 1,808T]
│   ├── dictmagic.py [526L 18,246C 1,982T]
│   ├── errormode.py [241L 7,939C 816T]
│   ├── group_equiv.py [ 66L 2,060C 246T]
│   ├── interval.py [530L 18,387C 1,750T]
│   ├── jsonlines.py [ 77L 1,993C 246T]
│   ├── kappa.py [ 48L 1,464C 182T]
│   ├── mlutils.py [183L 6,106C 587T]
│   ├── parallel.py [280L 9,659C 1,118T]
│   ├── py.typed [ 0L 0C 0T]
│   ├── spinner.py [525L 18,312C 1,738T]
│   ├── statcounter.py [231L 7,397C 771T]
│   ├── sysinfo.py [210L 7,185C 514T]
│   ├── tensor_info.py [725L 23,964C 2,367T]
│   ├── tensor_utils.py [506L 16,115C 1,615T]
│   ├── timeit_fancy.py [113L 4,329C 521T]
│   └── validate_type.py [244L 8,661C 931T]
├── CHANGELOG.md [ 22L 853C 98T]
├── LICENSE [674L 35,149C 5,644T]
├── README.md [ 72L 5,710C 460T]
├── TODO.md [ 28L 1,388C 235T]
├── makefile [755L 30,080C 3,866T]
├── pyproject.toml [273L 8,032C 915T]
```

# File Contents

``````{ path=".github/workflows/checks.yml" }
name: Checks on: pull_request: branches: - main - "*" push: branches: - main jobs: lint: name: Formatting runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 with: fetch-depth: 1 - name: install format tools run: pip install -r .meta/requirements/requirements-lint.txt - name: Run Format Checks run: make format-check RUN_GLOBAL=1 check-deps: name: Check dependencies runs-on: ubuntu-latest steps: - name: Checkout code uses: 
actions/checkout@v4 with: fetch-depth: 1 - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.10' - name: set up uv run: curl -LsSf https://astral.sh/uv/install.sh | sh - name: check dependencies run: make dep-check test: name: Test and Lint runs-on: ubuntu-latest # needs: [lint, check-deps] # for conditionally running this job strategy: matrix: python: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] pkg: - group: "legacy" torch: "1.13.1" numpy: "1.24.4" pandas: "2.0.3" pillow: "10.4.0" - group: "latest" torch: "" numpy: "" pandas: "" pillow: "" exclude: - python: "3.12" pkg: group: "legacy" - python: "3.13" pkg: group: "legacy" - python: "3.14" pkg: group: "legacy" steps: - name: Checkout code uses: actions/checkout@v4 with: fetch-depth: 1 - name: Set up python uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: set up uv run: curl -LsSf https://astral.sh/uv/install.sh | sh - name: install run: make setup - name: Install different pytorch version if: ${{ matrix.pkg.torch != '' && matrix.python != '3.14' }} run: | uv pip install torch==${{ matrix.pkg.torch }}+cpu --extra-index-url https://download.pytorch.org/whl/cpu - name: Install legacy package versions if: ${{ matrix.pkg.group == 'legacy' }} run: uv pip install numpy==${{ matrix.pkg.numpy }} pandas==${{ matrix.pkg.pandas }} pillow==${{ matrix.pkg.pillow }} - name: tests run: make test UV_NOSYNC=1 # - name: tests in strict mode # # TODO: until zanj ported to 3.8 and 3.9 # if: ${{ matrix.python != '3.8' && matrix.python != '3.9' }} # run: make test WARN_STRICT=1 - name: "check typing" if: ${{ matrix.python != '3.8' }} run: make typing UV_NOSYNC=1 - name: "check typing (legacy: log only, always pass)" if: ${{ matrix.python == '3.8' }} run: make typing UV_NOSYNC=1 || true ``````{ end_of_file=".github/workflows/checks.yml" } ``````{ path=".github/workflows/make-docs.yml" } # this workflow partially copied from # https://github.com/TransformerLensOrg/TransformerLens/blob/main/.github/workflows/checks.yml name: make docs on: pull_request: branches: - main - "*" push: branches: - main jobs: build-docs: # When running on a PR, this just checks we can build the docs without errors # When running on merge to main, it builds the docs and then another job deploys them name: 'Build Docs' runs-on: ubuntu-latest if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev') || contains(github.head_ref, 'docs') steps: - name: Install pandoc uses: awalsh128/cache-apt-pkgs-action@latest with: packages: pandoc version: '3.3' - name: Check pandoc version run: pandoc --version - name: Checkout code uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.13' - name: set up uv run: curl -LsSf https://astral.sh/uv/install.sh | sh - name: Install run: make setup - name: Build Docs run: make docs ``````{ end_of_file=".github/workflows/make-docs.yml" } ``````{ path="muutils/cli/arg_bool.py" } from __future__ import annotations import argparse import sys from collections.abc import Iterable, Sequence from typing import Any, Callable, Final, TypeVar if sys.version_info >= (3, 12): from typing import override else: from typing_extensions import override T_callable = TypeVar("T_callable", bound=Callable[..., Any]) def format_function_docstring( mapping: dict[str, Any], /, ) -> Callable[[T_callable], T_callable]: """Decorator to format function docstring with the given keyword arguments""" # I think we 
don't need to use functools.wraps here, since we return the same function def decorator(func: T_callable) -> T_callable: assert func.__doc__ is not None, "Function must have a docstring to format." func.__doc__ = func.__doc__.format_map(mapping) return func return decorator # Default token sets (lowercase). You can override per-option. TRUE_SET_DEFAULT: Final[set[str]] = {"1", "true", "t", "yes", "y", "on"} FALSE_SET_DEFAULT: Final[set[str]] = {"0", "false", "f", "no", "n", "off"} def _normalize_set(tokens: Iterable[str] | None, fallback: set[str]) -> set[str]: """Normalize a collection of tokens to a lowercase set, or return fallback.""" if tokens is None: return set(fallback) return {str(t).lower() for t in tokens} def parse_bool_token( token: str, true_set: set[str] | None = None, false_set: set[str] | None = None, ) -> bool: """Strict string-to-bool converter for argparse and friends. # Parameters: - `token : str` input token - `true_set : set[str] | None` accepted truthy strings (case-insensitive). Defaults to TRUE_SET_DEFAULT when None. - `false_set : set[str] | None` accepted falsy strings (case-insensitive). Defaults to FALSE_SET_DEFAULT when None. # Returns: - `bool` parsed boolean # Raises: - `argparse.ArgumentTypeError` : if not a recognized boolean string """ ts: set[str] = _normalize_set(true_set, TRUE_SET_DEFAULT) fs: set[str] = _normalize_set(false_set, FALSE_SET_DEFAULT) v: str = token.lower() if v in ts: return True if v in fs: return False valid: list[str] = sorted(ts | fs) raise argparse.ArgumentTypeError(f"expected one of {valid}") class BoolFlagOrValue(argparse.Action): """summary Configurable boolean action supporting any combination of: --flag -> True (if allow_bare) --no-flag -> False (if allow_no and --no-flag is registered) --flag true|false -> parsed via custom sets --flag=true|false -> parsed via custom sets Notes: - The --no-flag form never accepts a value. It forces False. - If allow_no is False but you still register a --no-flag alias, using it will produce a usage error. - Do not pass type= to this action. # Parameters: - `option_strings : list[str]` provided by argparse - `dest : str` attribute name on the namespace - `nargs : int | str | None` must be '?' for optional value - `true_set : set[str] | None` accepted truthy strings (case-insensitive). Defaults provided. - `false_set : set[str] | None` accepted falsy strings (case-insensitive). Defaults provided. - `allow_no : bool` whether the --no-flag form is allowed (defaults to True) - `allow_bare : bool` whether bare --flag (no value) is allowed (defaults to True) - `**kwargs` forwarded to base class # Raises: - `ValueError` : if nargs is not '?' or if type= is provided """ def __init__( self, option_strings: Sequence[str], dest: str, nargs: int | str | None = None, true_set: set[str] | None = None, false_set: set[str] | None = None, allow_no: bool = True, allow_bare: bool = True, **kwargs: Any, ) -> None: if "type" in kwargs and kwargs["type"] is not None: raise ValueError("BoolFlagOrValue does not accept type=. 
Remove it.") if nargs not in (None, "?"): raise ValueError("BoolFlagOrValue requires nargs='?'") super().__init__( option_strings=option_strings, dest=dest, nargs="?", **kwargs, ) # Store normalized config self.true_set: set[str] = _normalize_set(true_set, TRUE_SET_DEFAULT) self.false_set: set[str] = _normalize_set(false_set, FALSE_SET_DEFAULT) self.allow_no: bool = allow_no self.allow_bare: bool = allow_bare def _parse_token(self, token: str) -> bool: """Parse a boolean token using this action's configured sets.""" return parse_bool_token(token, self.true_set, self.false_set) @override def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: str | Sequence[str] | None, option_string: str | None = None, ) -> None: # Negated form handling if option_string is not None and option_string.startswith("--no-"): if not self.allow_no: parser.error(f"{option_string} is not allowed for this option") return if values is not None: dest_flag: str = self.dest.replace("_", "-") parser.error( f"{option_string} does not take a value; use --{dest_flag} true|false" ) return setattr(namespace, self.dest, False) return # Bare positive flag -> True (if allowed) if values is None: if not self.allow_bare: valid: list[str] = sorted(self.true_set | self.false_set) parser.error( f"option {option_string} requires a value; expected one of {valid}" ) return setattr(namespace, self.dest, True) return # we take only one value if not isinstance(values, str): if len(values) != 1: parser.error( f"{option_string} expects a single value, got {len(values) = }, {values = }" ) return values = values[0] # type: ignore[assignment] # Positive flag with explicit value -> parse try: val: bool = self._parse_token(values) except argparse.ArgumentTypeError as e: parser.error(str(e)) return setattr(namespace, self.dest, val) def add_bool_flag( parser: argparse.ArgumentParser, name: str, *, default: bool = False, help: str = "", true_set: set[str] | None = None, false_set: set[str] | None = None, allow_no: bool = False, allow_bare: bool = True, ) -> None: """summary Add a configurable boolean option that supports (depending on options): -- (bare positive, if allow_bare) --no- (negated, if allow_no) -- true|false --=true|false # Parameters: - `parser : argparse.ArgumentParser` parser to modify - `name : str` base long option name (without leading dashes) - `default : bool` default value (defaults to False) - `help : str` help text (optional) - `true_set : set[str] | None` accepted truthy strings (case-insensitive). Defaults used when None. - `false_set : set[str] | None` accepted falsy strings (case-insensitive). Defaults used when None. 
- `allow_no : bool` whether to register/allow the --no- alias (defaults to True) - `allow_bare : bool` whether bare -- implies True (defaults to True) # Returns: - `None` nothing; parser is modified # Modifies: - `parser` : adds a new argument with dest `` (hyphens -> underscores) # Usage: ```python p = argparse.ArgumentParser() add_bool_flag(p, "feature", default=False, help="enable/disable feature") ns = p.parse_args(["--feature=false"]) assert ns.feature is False ``` """ long_opt: str = f"--{name}" dest: str = name.replace("-", "_") option_strings: list[str] = [long_opt] if allow_no: option_strings.append(f"--no-{name}") tokens_preview: str = "{true,false}" readable_name: str = name.replace("-", " ") arg_help: str = help or ( f"enable/disable {readable_name}; also accepts explicit true|false" ) parser.add_argument( *option_strings, dest=dest, action=BoolFlagOrValue, nargs="?", default=default, metavar=tokens_preview, help=arg_help, true_set=true_set, false_set=false_set, allow_no=allow_no, allow_bare=allow_bare, ) ``````{ end_of_file="muutils/cli/arg_bool.py" } ``````{ path="muutils/cli/command.py" } from __future__ import annotations import os import subprocess import sys from dataclasses import dataclass from typing import Any, List, Union @dataclass class Command: """Simple typed command with shell flag and subprocess helpers.""" cmd: Union[List[str], str] shell: bool = False env: dict[str, str] | None = None inherit_env: bool = True def __post_init__(self) -> None: """Enforce cmd type when shell is False.""" if self.shell is False and isinstance(self.cmd, str): raise ValueError("cmd must be List[str] when shell is False") def _quote_env(self) -> str: """Return KEY=VAL tokens for env values. ignores `inherit_env`.""" if not self.env: return "" parts: List[str] = [] for k, v in self.env.items(): token: str = f"{k}={v}" parts.append(token) prefix: str = " ".join(parts) return prefix @property def cmd_joined(self) -> str: """Return cmd as a single string, joining with spaces if it's a list. 
no env included.""" if isinstance(self.cmd, str): return self.cmd else: return " ".join(self.cmd) @property def cmd_for_subprocess(self) -> Union[List[str], str]: """Return cmd, splitting if shell is True and cmd is a string.""" if self.shell: if isinstance(self.cmd, str): return self.cmd else: return " ".join(self.cmd) else: assert isinstance(self.cmd, list) return self.cmd def script_line(self) -> str: """Return a single shell string, prefixing KEY=VAL for env if provided.""" return f"{self._quote_env()} {self.cmd_joined}".strip() @property def env_final(self) -> dict[str, str]: """Return final env dict, merging with os.environ if inherit_env is True.""" return { **(os.environ if self.inherit_env else {}), **(self.env or {}), } def run( self, **kwargs: Any, ) -> subprocess.CompletedProcess[Any]: """Call subprocess.run with this command.""" try: return subprocess.run( self.cmd_for_subprocess, shell=self.shell, env=self.env_final, **kwargs, ) except subprocess.CalledProcessError as e: print(f"Command failed: `{self.script_line()}`", file=sys.stderr) raise e def Popen( self, **kwargs: Any, ) -> subprocess.Popen[Any]: """Call subprocess.Popen with this command.""" return subprocess.Popen( self.cmd_for_subprocess, shell=self.shell, env=self.env_final, **kwargs, ) ``````{ end_of_file="muutils/cli/command.py" } ``````{ path="muutils/json_serialize/__init__.py" } """submodule for serializing things to json in a recoverable way you can throw *any* object into `muutils.json_serialize.json_serialize` and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping to `JSONitem`. The goal of this is that if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ). It will do so by looking in `DEFAULT_HANDLERS`, which will keep it as-is if it's already valid, then try to find a `.serialize()` method on the object, and then have a bunch of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`. Additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and `serializable_field` in place of `dataclasses.field` when defining non-standard fields. This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes. 
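
A minimal usage sketch (the exact serialized form of non-primitive objects depends on which handler matches; the outputs shown here are illustrative):

```python
from pathlib import Path
from muutils.json_serialize import json_serialize

# plain JSON types pass through unchanged
json_serialize({"a": [1, 2.5, None], "ok": True})
# -> {'a': [1, 2.5, None], 'ok': True}

# non-JSON types fall through to the default handlers,
# e.g. Path objects are serialized as posix strings
json_serialize({"config": Path("data") / "cfg.json"})
# -> {'config': 'data/cfg.json'}
```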
""" from __future__ import annotations from muutils.json_serialize.array import arr_metadata, load_array from muutils.json_serialize.json_serialize import ( BASE_HANDLERS, JsonSerializer, json_serialize, ) from muutils.json_serialize.serializable_dataclass import ( SerializableDataclass, serializable_dataclass, serializable_field, ) from muutils.json_serialize.util import try_catch, JSONitem, dc_eq __all__ = [ # submodules "array", "json_serialize", "serializable_dataclass", "serializable_field", "util", # imports "arr_metadata", "load_array", "BASE_HANDLERS", "JSONitem", "JsonSerializer", "json_serialize", "try_catch", "JSONitem", "dc_eq", "serializable_dataclass", "serializable_field", "SerializableDataclass", ] ``````{ end_of_file="muutils/json_serialize/__init__.py" } ``````{ path="muutils/json_serialize/array.py" } """this utilities module handles serialization and loading of numpy and torch arrays as json - `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability. - `array_b64_meta` is the most efficient, but is not human readable. - `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ) """ from __future__ import annotations import base64 import typing import warnings from typing import ( TYPE_CHECKING, Any, Iterable, Literal, Optional, Sequence, TypedDict, Union, overload, ) try: import numpy as np except ImportError as e: warnings.warn( f"numpy is not installed, array serialization will not work: \n{e}", ImportWarning, ) if TYPE_CHECKING: import numpy as np import torch from muutils.json_serialize.json_serialize import JsonSerializer from muutils.json_serialize.types import _FORMAT_KEY # pyright: ignore[reportPrivateUsage] # TYPING: pyright complains way too much here # pyright: reportCallIssue=false,reportArgumentType=false,reportUnknownVariableType=false,reportUnknownMemberType=false # Recursive type for nested numeric lists (output of arr.tolist()) NumericList = typing.Union[ typing.List[typing.Union[int, float, bool]], typing.List["NumericList"], ] ArrayMode = Literal[ "list", "array_list_meta", "array_hex_meta", "array_b64_meta", "external", "zero_dim", ] # Modes that produce SerializedArrayWithMeta (dict with metadata) ArrayModeWithMeta = Literal[ "array_list_meta", "array_hex_meta", "array_b64_meta", "zero_dim", "external", ] def array_n_elements(arr: Any) -> int: # type: ignore[name-defined] # pyright: ignore[reportAny] """get the number of elements in an array""" if isinstance(arr, np.ndarray): return arr.size elif str(type(arr)) == "": # pyright: ignore[reportUnknownArgumentType, reportAny] assert hasattr(arr, "nelement"), ( "torch Tensor does not have nelement() method? 
this should not happen" ) # pyright: ignore[reportAny] return arr.nelement() # pyright: ignore[reportAny] else: raise TypeError(f"invalid type: {type(arr)}") # pyright: ignore[reportAny] class ArrayMetadata(TypedDict): """Metadata for a numpy/torch array""" shape: list[int] dtype: str n_elements: int class SerializedArrayWithMeta(TypedDict): """Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)""" __muutils_format__: str data: typing.Union[ NumericList, str, int, float, bool ] # list, hex str, b64 str, or scalar for zero_dim shape: list[int] dtype: str n_elements: int def arr_metadata(arr: Any) -> ArrayMetadata: # pyright: ignore[reportAny] """get metadata for a numpy array""" return { "shape": list(arr.shape), # pyright: ignore[reportAny] "dtype": ( arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype) # pyright: ignore[reportAny] ), "n_elements": array_n_elements(arr), } @overload def serialize_array( jser: "JsonSerializer", arr: "Union[np.ndarray, torch.Tensor]", path: str | Sequence[str | int], array_mode: Literal["list"], ) -> NumericList: ... @overload def serialize_array( jser: "JsonSerializer", arr: "Union[np.ndarray, torch.Tensor]", path: str | Sequence[str | int], array_mode: ArrayModeWithMeta, ) -> SerializedArrayWithMeta: ... @overload def serialize_array( jser: "JsonSerializer", arr: "Union[np.ndarray, torch.Tensor]", path: str | Sequence[str | int], array_mode: None = None, ) -> SerializedArrayWithMeta | NumericList: ... def serialize_array( jser: "JsonSerializer", # type: ignore[name-defined] # noqa: F821 arr: "Union[np.ndarray, torch.Tensor]", path: str | Sequence[str | int], # pyright: ignore[reportUnusedParameter] array_mode: ArrayMode | None = None, ) -> SerializedArrayWithMeta | NumericList: """serialize a numpy or pytorch array in one of several modes if the object is zero-dimensional, simply get the unique item `array_mode: ArrayMode` can be one of: - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`) - `array_list_meta`: serialize dict with metadata, actual list under the key `data` - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data` - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data` for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is: ``` { _FORMAT_KEY: , "shape": arr.shape, "dtype": str(arr.dtype), "data": , } ``` # Parameters: - `arr : Any` array to serialize - `array_mode : ArrayMode` mode in which to serialize the array (defaults to `None` and inheriting from `jser: JsonSerializer`) # Returns: - `JSONitem` json serialized array # Raises: - `KeyError` : if the array mode is not valid """ if array_mode is None: array_mode = jser.array_mode arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}" arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr) # pyright: ignore[reportUnnecessaryIsInstance] # Handle list mode first (no metadata needed) if array_mode == "list": return arr_np.tolist() # pyright: ignore[reportAny] # For all other modes, compute metadata once metadata: ArrayMetadata = arr_metadata(arr if len(arr.shape) == 0 else arr_np) # TYPING: ty<=0.0.1a24 does not appear to support unpacking TypedDicts, so we do things manually. change it back later maybe? 
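# Illustrative sketch of the zero-dim case handled below (values assumed for illustration;
# the dtype string depends on the input array and platform):
#   serialize_array(jser, np.array(5), path=())
#   -> {"__muutils_format__": "numpy.ndarray:zero_dim", "data": 5,
#       "shape": [], "dtype": "int64", "n_elements": 1}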
# handle zero-dimensional arrays if len(arr.shape) == 0: return SerializedArrayWithMeta( __muutils_format__=f"{arr_type}:zero_dim", data=arr.item(), # pyright: ignore[reportAny] shape=metadata["shape"], dtype=metadata["dtype"], n_elements=metadata["n_elements"], ) # Handle the metadata modes if array_mode == "array_list_meta": return SerializedArrayWithMeta( __muutils_format__=f"{arr_type}:array_list_meta", data=arr_np.tolist(), # pyright: ignore[reportAny] shape=metadata["shape"], dtype=metadata["dtype"], n_elements=metadata["n_elements"], ) elif array_mode == "array_hex_meta": return SerializedArrayWithMeta( __muutils_format__=f"{arr_type}:array_hex_meta", data=arr_np.tobytes().hex(), shape=metadata["shape"], dtype=metadata["dtype"], n_elements=metadata["n_elements"], ) elif array_mode == "array_b64_meta": return SerializedArrayWithMeta( __muutils_format__=f"{arr_type}:array_b64_meta", data=base64.b64encode(arr_np.tobytes()).decode(), shape=metadata["shape"], dtype=metadata["dtype"], n_elements=metadata["n_elements"], ) else: raise KeyError(f"invalid array_mode: {array_mode}") @overload def infer_array_mode( arr: SerializedArrayWithMeta, ) -> ArrayModeWithMeta: ... @overload def infer_array_mode(arr: NumericList) -> Literal["list"]: ... def infer_array_mode( arr: Union[SerializedArrayWithMeta, NumericList], ) -> ArrayMode: """given a serialized array, infer the mode assumes the array was serialized via `serialize_array()` """ return_mode: ArrayMode if isinstance(arr, typing.Mapping): # _FORMAT_KEY always maps to a string fmt: str = arr.get(_FORMAT_KEY, "") # type: ignore if fmt.endswith(":array_list_meta"): arr_data = arr["data"] # ty: ignore[invalid-argument-type] if not isinstance(arr_data, Iterable): raise ValueError(f"invalid list format: {type(arr_data) = }\t{arr}") return_mode = "array_list_meta" elif fmt.endswith(":array_hex_meta"): arr_data = arr["data"] # ty: ignore[invalid-argument-type] if not isinstance(arr_data, str): raise ValueError(f"invalid hex format: {type(arr_data) = }\t{arr}") return_mode = "array_hex_meta" elif fmt.endswith(":array_b64_meta"): arr_data = arr["data"] # ty: ignore[invalid-argument-type] if not isinstance(arr_data, str): raise ValueError(f"invalid b64 format: {type(arr_data) = }\t{arr}") return_mode = "array_b64_meta" elif fmt.endswith(":external"): return_mode = "external" elif fmt.endswith(":zero_dim"): return_mode = "zero_dim" else: raise ValueError(f"invalid format: {arr}") elif isinstance(arr, list): # pyright: ignore[reportUnnecessaryIsInstance] return_mode = "list" else: raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }") # pyright: ignore[reportUnreachable] return return_mode @overload def load_array( arr: SerializedArrayWithMeta, array_mode: Optional[ArrayModeWithMeta] = None, ) -> np.ndarray: ... @overload def load_array( arr: NumericList, array_mode: Optional[Literal["list"]] = None, ) -> np.ndarray: ... @overload def load_array( arr: np.ndarray, array_mode: None = None, ) -> np.ndarray: ... 
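# Illustrative round-trip sketch (assumes numpy is installed; `JsonSerializer` here is
# imported from muutils.json_serialize.json_serialize):
#   arr = np.arange(6).reshape(2, 3)
#   s = serialize_array(JsonSerializer(), arr, path=(), array_mode="array_list_meta")
#   np.array_equal(load_array(s), arr)  # -> True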
def load_array( arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList], array_mode: Optional[ArrayMode] = None, ) -> np.ndarray: """load a json-serialized array, infer the mode if not specified""" # return arr if its already a numpy array if isinstance(arr, np.ndarray): assert array_mode is None, ( "array_mode should not be specified when loading a numpy array, since that is a no-op" ) return arr # try to infer the array_mode array_mode_inferred: ArrayMode = infer_array_mode(arr) if array_mode is None: array_mode = array_mode_inferred elif array_mode != array_mode_inferred: warnings.warn( f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}" ) # actually load the array if array_mode == "array_list_meta": assert isinstance(arr, typing.Mapping), ( f"invalid list format: {type(arr) = }\n{arr = }" ) data = np.array(arr["data"], dtype=arr["dtype"]) # type: ignore if tuple(arr["shape"]) != tuple(data.shape): # type: ignore raise ValueError(f"invalid shape: {arr}") return data elif array_mode == "array_hex_meta": assert isinstance(arr, typing.Mapping), ( f"invalid list format: {type(arr) = }\n{arr = }" ) data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"]) # type: ignore return data.reshape(arr["shape"]) # type: ignore elif array_mode == "array_b64_meta": assert isinstance(arr, typing.Mapping), ( f"invalid list format: {type(arr) = }\n{arr = }" ) data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"]) # type: ignore return data.reshape(arr["shape"]) # type: ignore elif array_mode == "list": assert isinstance(arr, typing.Sequence), ( f"invalid list format: {type(arr) = }\n{arr = }" ) return np.array(arr) # type: ignore elif array_mode == "external": assert isinstance(arr, typing.Mapping) if "data" not in arr: raise KeyError( # pyright: ignore[reportUnreachable] f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}" ) # we can ignore here since we assume ZANJ has taken care of it return arr["data"] # type: ignore[return-value] # pyright: ignore[reportReturnType] elif array_mode == "zero_dim": assert isinstance(arr, typing.Mapping) data = np.array(arr["data"]) # ty: ignore[invalid-argument-type] if tuple(arr["shape"]) != tuple(data.shape): # type: ignore raise ValueError(f"invalid shape: {arr}") return data else: raise ValueError(f"invalid array_mode: {array_mode}") # pyright: ignore[reportUnreachable] ``````{ end_of_file="muutils/json_serialize/array.py" } ``````{ path="muutils/json_serialize/dataclass_transform_mock.py" } from __future__ import annotations import typing from typing import Any, Union def dataclass_transform( *, eq_default: bool = True, order_default: bool = False, kw_only_default: bool = False, frozen_default: bool = False, field_specifiers: tuple[Union[type[Any], typing.Callable[..., Any]], ...] 
= (), **kwargs: Any, ) -> typing.Callable[[Any], Any]: "mock `typing.dataclass_transform` for python <3.11" def decorator(cls_or_fn: Any) -> Any: cls_or_fn.__dataclass_transform__ = { "eq_default": eq_default, "order_default": order_default, "kw_only_default": kw_only_default, "frozen_default": frozen_default, "field_specifiers": field_specifiers, "kwargs": kwargs, } return cls_or_fn return decorator ``````{ end_of_file="muutils/json_serialize/dataclass_transform_mock.py" } ``````{ path="muutils/json_serialize/json_serialize.py" } """provides the basic framework for json serialization of objects notably: - `SerializerHandler` defines how to serialize a specific type of object - `JsonSerializer` handles configuration for which handlers to use - `json_serialize` provides the default configuration if you don't care -- call it on any object! """ from __future__ import annotations import inspect import warnings from dataclasses import dataclass, is_dataclass from pathlib import Path from typing import ( TYPE_CHECKING, Any, Callable, Iterable, Mapping, Set, Union, cast, overload, ) from muutils.errormode import ErrorMode if TYPE_CHECKING: # always need array.py for type checking from muutils.json_serialize.array import ArrayMode, serialize_array else: try: from muutils.json_serialize.array import ArrayMode, serialize_array except ImportError as e: # TYPING: obviously, these types are all wrong if we can't import array.py ArrayMode = str # type: ignore[misc] serialize_array = lambda *args, **kwargs: None # type: ignore[assignment, invalid-assignment] # noqa: E731 warnings.warn( f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}", ImportWarning, ) from muutils.json_serialize.types import ( _FORMAT_KEY, Hashableitem, ) # pyright: ignore[reportPrivateUsage] from muutils.json_serialize.util import ( JSONdict, JSONitem, MonoTuple, SerializationException, _recursive_hashify, # pyright: ignore[reportPrivateUsage, reportUnknownVariableType] isinstance_namedtuple, safe_getsource, string_as_lines, try_catch, ) # pylint: disable=protected-access SERIALIZER_SPECIAL_KEYS: MonoTuple[str] = ( "__name__", "__doc__", "__module__", "__class__", "__dict__", "__annotations__", ) SERIALIZER_SPECIAL_FUNCS: dict[str, Callable[..., str | list[str]]] = { "str": str, "dir": dir, "type": try_catch(lambda x: str(type(x).__name__)), # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType] "repr": try_catch(lambda x: repr(x)), # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType] "code": try_catch(lambda x: inspect.getsource(x)), # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType] "sourcefile": try_catch(lambda x: str(inspect.getsourcefile(x))), # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType] } SERIALIZE_DIRECT_AS_STR: Set[str] = { "", "", } ObjectPath = MonoTuple[Union[str, int]] @dataclass class SerializerHandler: """a handler for a specific type of object # Parameters: - `check : Callable[[JsonSerializer, Any], bool]` takes a JsonSerializer and an object, returns whether to use this handler - `serialize : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]` takes a JsonSerializer, an object, and the current path, returns the serialized object - `desc : str` description of the handler (optional) """ # (self_config, object) -> whether to use this handler check: Callable[["JsonSerializer", Any, ObjectPath], bool] # (self_config, object, path) -> serialized object serialize_func: 
Callable[["JsonSerializer", Any, ObjectPath], JSONitem] # unique identifier for the handler uid: str # description of this serializer desc: str def serialize(self) -> JSONdict: """serialize the handler info""" return { # get the code and doc of the check function "check": { "code": safe_getsource(self.check), "doc": string_as_lines(self.check.__doc__), }, # get the code and doc of the load function "serialize_func": { "code": safe_getsource(self.serialize_func), "doc": string_as_lines(self.serialize_func.__doc__), }, # get the uid, source_pckg, priority, and desc "uid": str(self.uid), "source_pckg": getattr(self.serialize_func, "source_pckg", None), "__module__": getattr(self.serialize_func, "__module__", None), "desc": str(self.desc), } BASE_HANDLERS: MonoTuple[SerializerHandler] = ( SerializerHandler( check=lambda self, obj, path: isinstance( obj, (bool, int, float, str, type(None)) ), serialize_func=lambda self, obj, path: obj, uid="base types", desc="base types (bool, int, float, str, None)", ), SerializerHandler( check=lambda self, obj, path: isinstance(obj, Mapping), serialize_func=lambda self, obj, path: { str(k): self.json_serialize(v, tuple(path) + (k,)) for k, v in obj.items() }, uid="dictionaries", desc="dictionaries", ), SerializerHandler( check=lambda self, obj, path: isinstance_namedtuple(obj), serialize_func=lambda self, obj, path: { str(k): self.json_serialize(v, tuple(path) + (k,)) for k, v in obj._asdict().items() }, uid="namedtuple -> dict", desc="namedtuples as dicts", ), SerializerHandler( check=lambda self, obj, path: isinstance(obj, (list, tuple)), serialize_func=lambda self, obj, path: [ self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj) ], uid="(list, tuple) -> list", desc="lists and tuples as lists", ), ) def _serialize_override_serialize_func( self: "JsonSerializer", obj: Any, path: ObjectPath ) -> JSONitem: # obj_cls: type = type(obj) # if hasattr(obj_cls, "_register_self") and callable(obj_cls._register_self): # obj_cls._register_self() # get the serialized object return obj.serialize() DEFAULT_HANDLERS: MonoTuple[SerializerHandler] = tuple(BASE_HANDLERS) + ( SerializerHandler( # TODO: allow for custom serialization handler name check=lambda self, obj, path: ( hasattr(obj, "serialize") and callable(obj.serialize) ), serialize_func=_serialize_override_serialize_func, uid=".serialize override", desc="objects with .serialize method", ), SerializerHandler( check=lambda self, obj, path: is_dataclass(obj), serialize_func=lambda self, obj, path: { k: self.json_serialize(getattr(obj, k), tuple(path) + (k,)) for k in obj.__dataclass_fields__ }, uid="dataclass -> dict", desc="dataclasses as dicts", ), SerializerHandler( check=lambda self, obj, path: isinstance(obj, Path), serialize_func=lambda self, obj, path: obj.as_posix(), uid="path -> str", desc="Path objects as posix strings", ), SerializerHandler( check=lambda self, obj, path: str(type(obj)) in SERIALIZE_DIRECT_AS_STR, serialize_func=lambda self, obj, path: str(obj), uid="obj -> str(obj)", desc="directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings", ), SerializerHandler( check=lambda self, obj, path: str(type(obj)) == "", serialize_func=lambda self, obj, path: cast( JSONitem, serialize_array(self, obj, path=path) ), uid="numpy.ndarray", desc="numpy arrays", ), SerializerHandler( check=lambda self, obj, path: str(type(obj)) == "", serialize_func=lambda self, obj, path: cast( JSONitem, serialize_array( self, obj.detach().cpu(), path=path, # pyright: ignore[reportAny] ), ), 
uid="torch.Tensor", desc="pytorch tensors", ), SerializerHandler( check=lambda self, obj, path: ( str(type(obj)) == "" ), # TYPING: type checkers have no idea that obj is a DataFrame here serialize_func=lambda self, obj, path: { # pyright: ignore[reportArgumentType, reportAny] _FORMAT_KEY: "pandas.DataFrame", # type: ignore[misc] "columns": obj.columns.tolist(), # pyright: ignore[reportAny] "data": obj.to_dict(orient="records"), # pyright: ignore[reportAny] "path": path, }, uid="pandas.DataFrame", desc="pandas DataFrames", ), SerializerHandler( check=lambda self, obj, path: isinstance(obj, (set, frozenset)), serialize_func=lambda self, obj, path: { _FORMAT_KEY: "set" if isinstance(obj, set) else "frozenset", # type: ignore[misc] "data": [ self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj) ], }, uid="set -> dict[_FORMAT_KEY: 'set', data: list(...)]", desc="sets as dicts with format key", ), SerializerHandler( check=lambda self, obj, path: ( isinstance(obj, Iterable) and not isinstance(obj, (list, tuple, str)) ), serialize_func=lambda self, obj, path: [ self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj) ], uid="Iterable -> list", desc="Iterables (not lists/tuples/strings) as lists", ), SerializerHandler( check=lambda self, obj, path: True, serialize_func=lambda self, obj, path: { **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS}, # type: ignore[typeddict-item] **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()}, }, uid="fallback", desc="fallback handler -- serialize object attributes and special functions as strings", ), ) class JsonSerializer: """Json serialization class (holds configs) # Parameters: - `array_mode : ArrayMode` how to write arrays (defaults to `"array_list_meta"`) - `error_mode : ErrorMode` what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn") (defaults to `"except"`) - `handlers_pre : MonoTuple[SerializerHandler]` handlers to use before the default handlers (defaults to `tuple()`) - `handlers_default : MonoTuple[SerializerHandler]` default handlers to use (defaults to `DEFAULT_HANDLERS`) - `write_only_format : bool` changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading) (defaults to `False`) # Raises: - `ValueError`: on init, if `args` is not empty - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT"` """ def __init__( self, *args: None, array_mode: "ArrayMode" = "array_list_meta", error_mode: ErrorMode = ErrorMode.EXCEPT, handlers_pre: MonoTuple[SerializerHandler] = (), handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS, write_only_format: bool = False, ): if len(args) > 0: raise ValueError( f"JsonSerializer takes no positional arguments!\n{args = }" ) self.array_mode: "ArrayMode" = array_mode self.error_mode: ErrorMode = ErrorMode.from_any(error_mode) self.write_only_format: bool = write_only_format # join up the handlers self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple( handlers_default ) @overload def json_serialize( self, obj: Mapping[str, Any], path: ObjectPath = () ) -> JSONdict: ... @overload def json_serialize(self, obj: list, path: ObjectPath = ()) -> list: ... # @overload # pyright: ignore[reportOverlappingOverload] # def json_serialize(self, obj: set, path: ObjectPath = ()) -> _SerializedSet: ... 
# @overload # def json_serialize( # self, obj: frozenset, path: ObjectPath = () # ) -> _SerializedFrozenset: ... @overload def json_serialize(self, obj: Any, path: ObjectPath = ()) -> JSONitem: ... def json_serialize( self, obj: Any, # pyright: ignore[reportAny] path: ObjectPath = (), ) -> JSONitem: handler = None try: for handler in self.handlers: if handler.check(self, obj, path): output: JSONitem = handler.serialize_func(self, obj, path) if self.write_only_format: if isinstance(output, dict) and _FORMAT_KEY in output: # TYPING: JSONitem has no idea that _FORMAT_KEY is str new_fmt: str = output.pop(_FORMAT_KEY) # type: ignore # pyright: ignore[reportAssignmentType] output["__write_format__"] = new_fmt # type: ignore return output raise ValueError(f"no handler found for object with {type(obj) = }") # pyright: ignore[reportAny] except Exception as e: if self.error_mode == ErrorMode.EXCEPT: obj_str: str = repr(obj) # pyright: ignore[reportAny] if len(obj_str) > 1000: obj_str = obj_str[:1000] + "..." handler_uid = handler.uid if handler else "no handler matched" raise SerializationException( f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}" ) from e elif self.error_mode == ErrorMode.WARN: warnings.warn( f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}" ) return repr(obj) # pyright: ignore[reportAny] def hashify( self, obj: Any, # pyright: ignore[reportAny] path: ObjectPath = (), force: bool = True, ) -> Hashableitem: """try to turn any object into something hashable""" data = self.json_serialize(obj, path=path) # recursive hashify, turning dicts and lists into tuples return _recursive_hashify(data, force=force) GLOBAL_JSON_SERIALIZER: JsonSerializer = JsonSerializer() @overload def json_serialize(obj: Mapping[str, Any], path: ObjectPath = ()) -> JSONdict: ... @overload def json_serialize(obj: list, path: ObjectPath = ()) -> list: ... @overload # pyright: ignore[reportOverlappingOverload] # def json_serialize(obj: set, path: ObjectPath = ()) -> _SerializedSet: ... # @overload # def json_serialize(obj: frozenset, path: ObjectPath = ()) -> _SerializedFrozenset: ... @overload def json_serialize(obj: Any, path: ObjectPath = ()) -> JSONitem: ... def json_serialize(obj: Any, path: ObjectPath = ()) -> JSONitem: # pyright: ignore[reportAny] """serialize object to json-serializable object with default config""" return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path) ``````{ end_of_file="muutils/json_serialize/json_serialize.py" } ``````{ path="muutils/json_serialize/serializable_dataclass.py" } """save and load objects to and from json or compatible formats in a recoverable way `d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, you will get an error when you call `json.dumps(d)`. This module provides a way around that. Instead, you define your class: ```python @serializable_dataclass class MyClass(SerializableDataclass): a: int b: str ``` and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. 
So, you can do: >>> my_obj = MyClass(a=1, b="q") >>> s = json.dumps(my_obj.serialize()) >>> s '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' >>> read_obj = MyClass.load(json.loads(s)) >>> read_obj == my_obj True This isn't too impressive on its own, but it gets more useful when you have nested classses, or fields that are not json-serializable by default: ```python @serializable_dataclass class NestedClass(SerializableDataclass): x: str y: MyClass act_fun: torch.nn.Module = serializable_field( default=torch.nn.ReLU(), serialization_fn=lambda x: str(x), deserialize_fn=lambda x: getattr(torch.nn, x)(), ) ``` which gives us: >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) >>> s = json.dumps(nc.serialize()) >>> s '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' >>> read_nc = NestedClass.load(json.loads(s)) >>> read_nc == nc True """ from __future__ import annotations import abc import dataclasses import functools import json import sys import typing import warnings from typing import Any, Optional, Type, TypeVar, overload, TYPE_CHECKING from muutils.errormode import ErrorMode from muutils.validate_type import validate_type from muutils.json_serialize.serializable_field import ( SerializableField, serializable_field, ) from muutils.json_serialize.types import _FORMAT_KEY from muutils.json_serialize.util import ( JSONdict, array_safe_eq, dc_eq, ) # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access # For type checkers: always use typing_extensions which they can resolve # At runtime: use stdlib if available (3.11+), else typing_extensions, else mock if TYPE_CHECKING: from typing_extensions import dataclass_transform, Self else: if sys.version_info >= (3, 11): from typing import dataclass_transform, Self else: try: from typing_extensions import dataclass_transform, Self except Exception: from muutils.json_serialize.dataclass_transform_mock import ( dataclass_transform, ) Self = TypeVar("Self") T_SerializeableDataclass = TypeVar( "T_SerializeableDataclass", bound="SerializableDataclass" ) class CantGetTypeHintsWarning(UserWarning): "special warning for when we can't get type hints" pass class ZanjMissingWarning(UserWarning): "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work" pass _zanj_loading_needs_import: bool = True "flag to keep track of if we have successfully imported ZANJ" def zanj_register_loader_serializable_dataclass( cls: typing.Type[T_SerializeableDataclass], ): """Register a serializable dataclass with the ZANJ import this allows `ZANJ().read()` to load the class and not just return plain dicts # TODO: there is some duplication here with register_loader_handler """ global _zanj_loading_needs_import if _zanj_loading_needs_import: try: from zanj.loading import ( # type: ignore[import] # pyright: ignore[reportMissingImports] LoaderHandler, # pyright: ignore[reportUnknownVariableType] register_loader_handler, # pyright: ignore[reportUnknownVariableType] ) except ImportError: # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter # warnings.warn( # "ZANJ not installed, cannot register serializable dataclass loader. 
ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`", # ZanjMissingWarning, # ) return _format: str = f"{cls.__name__}(SerializableDataclass)" lh: LoaderHandler = LoaderHandler( # pyright: ignore[reportPossiblyUnboundVariable] check=lambda json_item, path=None, z=None: ( # type: ignore isinstance(json_item, dict) and _FORMAT_KEY in json_item and json_item[_FORMAT_KEY].startswith(_format) ), load=lambda json_item, path=None, z=None: cls.load(json_item), # type: ignore uid=_format, source_pckg=cls.__module__, desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass", ) register_loader_handler(lh) # pyright: ignore[reportPossiblyUnboundVariable] return lh _DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN _DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT class FieldIsNotInitOrSerializeWarning(UserWarning): pass def SerializableDataclass__validate_field_type( self: SerializableDataclass, field: SerializableField | str, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, ) -> bool: """given a dataclass, check the field matches the type hint this function is written to `SerializableDataclass.validate_field_type` # Parameters: - `self : SerializableDataclass` `SerializableDataclass` instance - `field : SerializableField | str` field to validate, will get from `self.__dataclass_fields__` if an `str` - `on_typecheck_error : ErrorMode` what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False` (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`) # Returns: - `bool` if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore` """ on_typecheck_error = ErrorMode.from_any(on_typecheck_error) # get field _field: SerializableField if isinstance(field, str): _field = self.__dataclass_fields__[field] # type: ignore[attr-defined] else: _field = field # do nothing case if not _field.assert_type: return True # if field is not `init` or not `serialize`, skip but warn # TODO: how to handle fields which are not `init` or `serialize`? if not _field.init or not _field.serialize: warnings.warn( f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked", FieldIsNotInitOrSerializeWarning, ) return True assert isinstance(_field, SerializableField), ( f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }" ) # get field type hints try: field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name] except KeyError as e: on_typecheck_error.process( ( f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n" + f"{get_cls_type_hints(self.__class__) = }\n" + f"Python version is {sys.version_info = }. You can:\n" + f" - disable `assert_type`. Currently: {_field.assert_type = }\n" + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). 
You had {_field.type = }\n" + " - use python 3.9.x or higher\n" + " - specify custom type validation function via `custom_typecheck_fn`\n" ), except_cls=TypeError, except_from=e, ) return False # get the value value: Any = getattr(self, _field.name) # validate the type try: type_is_valid: bool # validate the type with the default type validator if _field.custom_typecheck_fn is None: type_is_valid = validate_type(value, field_type_hint) # validate the type with a custom type validator else: type_is_valid = _field.custom_typecheck_fn(field_type_hint) return type_is_valid except Exception as e: on_typecheck_error.process( "exception while validating type: " + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }", except_cls=ValueError, except_from=e, ) return False def SerializableDataclass__validate_fields_types__dict( self: SerializableDataclass, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, ) -> dict[str, bool]: """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field returns a dict of field names to bools, where the bool is if the field type is valid """ on_typecheck_error = ErrorMode.from_any(on_typecheck_error) # if except, bundle the exceptions results: dict[str, bool] = dict() exceptions: dict[str, Exception] = dict() # for each field in the class cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self) # type: ignore[arg-type, assignment] for field in cls_fields: try: results[field.name] = self.validate_field_type(field, on_typecheck_error) except Exception as e: results[field.name] = False exceptions[field.name] = e # figure out what to do with the exceptions if len(exceptions) > 0: on_typecheck_error.process( f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}" + "\n\t" + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]), except_cls=ValueError, # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict except_from=list(exceptions.values())[0], ) return results def SerializableDataclass__validate_fields_types( self: SerializableDataclass, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, ) -> bool: """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field""" return all( SerializableDataclass__validate_fields_types__dict( self, on_typecheck_error=on_typecheck_error ).values() ) @dataclass_transform( field_specifiers=(serializable_field, SerializableField), ) class SerializableDataclass(abc.ABC): """Base class for serializable dataclasses only for linting and type checking, still need to call `serializable_dataclass` decorator # Usage: ```python @serializable_dataclass class MyClass(SerializableDataclass): a: int b: str ``` and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. 
So, you can do: >>> my_obj = MyClass(a=1, b="q") >>> s = json.dumps(my_obj.serialize()) >>> s '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' >>> read_obj = MyClass.load(json.loads(s)) >>> read_obj == my_obj True This isn't too impressive on its own, but it gets more useful when you have nested classses, or fields that are not json-serializable by default: ```python @serializable_dataclass class NestedClass(SerializableDataclass): x: str y: MyClass act_fun: torch.nn.Module = serializable_field( default=torch.nn.ReLU(), serialization_fn=lambda x: str(x), deserialize_fn=lambda x: getattr(torch.nn, x)(), ) ``` which gives us: >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) >>> s = json.dumps(nc.serialize()) >>> s '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' >>> read_nc = NestedClass.load(json.loads(s)) >>> read_nc == nc True """ def serialize(self) -> dict[str, Any]: "returns the class as a dict, implemented by using `@serializable_dataclass` decorator" raise NotImplementedError( f"decorate {self.__class__ = } with `@serializable_dataclass`" ) @overload @classmethod def load(cls, data: dict[str, Any]) -> Self: ... @overload @classmethod def load(cls, data: Self) -> Self: ... @classmethod def load(cls, data: dict[str, Any] | Self) -> Self: "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator" raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`") def validate_fields_types( self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR ) -> bool: """validate the types of all the fields on a `SerializableDataclass`. 
calls `SerializableDataclass__validate_field_type` for each field""" return SerializableDataclass__validate_fields_types( self, on_typecheck_error=on_typecheck_error ) def validate_field_type( self, field: "SerializableField|str", on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, ) -> bool: """given a dataclass, check the field matches the type hint""" return SerializableDataclass__validate_field_type( self, field, on_typecheck_error=on_typecheck_error ) def __eq__(self, other: Any) -> bool: return dc_eq(self, other) def __hash__(self) -> int: "hashes the json-serialized representation of the class" return hash(json.dumps(self.serialize())) def diff( self, other: "SerializableDataclass", of_serialized: bool = False ) -> dict[str, Any]: """get a rich and recursive diff between two instances of a serializable dataclass ```python >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3)) {'b': {'self': 2, 'other': 3}} >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3))) {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}} ``` # Parameters: - `other : SerializableDataclass` other instance to compare against - `of_serialized : bool` if true, compare serialized data and not raw values (defaults to `False`) # Returns: - `dict[str, Any]` # Raises: - `ValueError` : if the instances are not of the same type - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass` """ # match types if type(self) is not type(other): raise ValueError( f"Instances must be of the same type, but got {type(self) = } and {type(other) = }" ) # initialize the diff result diff_result: dict = {} # if they are the same, return the empty diff try: if self == other: return diff_result except Exception: pass # if we are working with serialized data, serialize the instances if of_serialized: ser_self: JSONdict = self.serialize() ser_other: JSONdict = other.serialize() # for each field in the class for field in dataclasses.fields(self): # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # skip fields that are not for comparison if not field.compare: continue # get values field_name: str = field.name self_value = getattr(self, field_name) other_value = getattr(other, field_name) # if the values are both serializable dataclasses, recurse if isinstance(self_value, SerializableDataclass) and isinstance( other_value, SerializableDataclass ): nested_diff: dict = self_value.diff( other_value, of_serialized=of_serialized ) if nested_diff: diff_result[field_name] = nested_diff # only support serializable dataclasses elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass( other_value ): raise ValueError("Non-serializable dataclass is not supported") else: # get the values of either the serialized or the actual values if of_serialized: self_value_s = ser_self[field_name] # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType] other_value_s = ser_other[field_name] # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType] else: self_value_s = self_value other_value_s = other_value # compare the values if not array_safe_eq(self_value_s, other_value_s): diff_result[field_name] = {"self": self_value, "other": other_value} # return the diff result return diff_result def update_from_nested_dict(self, nested_dict: dict[str, Any]): """update the instance from a nested dict, useful for configuration from command line args # Parameters: - `nested_dict : dict[str, Any]` nested dict to update the 
instance with """ for field in dataclasses.fields(self): # type: ignore[arg-type] field_name: str = field.name self_value = getattr(self, field_name) if field_name in nested_dict: if isinstance(self_value, SerializableDataclass): self_value.update_from_nested_dict(nested_dict[field_name]) else: setattr(self, field_name, nested_dict[field_name]) def __copy__(self) -> "SerializableDataclass": "deep copy by serializing and loading the instance to json" return self.__class__.load(json.loads(json.dumps(self.serialize()))) def __deepcopy__(self, memo: dict) -> "SerializableDataclass": "deep copy by serializing and loading the instance to json" return self.__class__.load(json.loads(json.dumps(self.serialize()))) # cache this so we don't have to keep getting it # TODO: are the types hashable? does this even make sense? @functools.lru_cache(typed=True) def get_cls_type_hints_cached(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]: "cached typing.get_type_hints for a class" return typing.get_type_hints(cls) def get_cls_type_hints(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]: "helper function to get type hints for a class" cls_type_hints: dict[str, Any] try: cls_type_hints = get_cls_type_hints_cached(cls) # type: ignore if len(cls_type_hints) == 0: cls_type_hints = typing.get_type_hints(cls) if len(cls_type_hints) == 0: raise ValueError(f"empty type hints for {cls.__name__ = }") except (TypeError, NameError, ValueError) as e: raise TypeError( f"Cannot get type hints for {cls = }\n" + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n" + f" {dataclasses.fields(cls) = }\n" # type: ignore[arg-type] + f" {e = }" ) from e return cls_type_hints class KWOnlyError(NotImplementedError): "kw-only dataclasses are not supported in python <3.9" pass class FieldError(ValueError): "base class for field errors" pass class NotSerializableFieldException(FieldError): "field is not a `SerializableField`" pass class FieldSerializationError(FieldError): "error while serializing a field" pass class FieldLoadingError(FieldError): "error while loading a field" pass class FieldTypeMismatchError(FieldError, TypeError): "error when a field type does not match the type hint" pass @dataclass_transform( field_specifiers=(serializable_field, SerializableField), ) def serializable_dataclass( # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it _cls=None, # type: ignore *, init: bool = True, repr: bool = True, # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass` eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH, methods_no_override: list[str] | None = None, **kwargs: Any, ) -> Any: """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!** types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE` behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass` Returns the same class as was passed in, with dunder methods added based on the fields defined in the class. Examines PEP 526 `__annotations__` to determine fields. 
If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation. ```python @serializable_dataclass(kw_only=True) class Myclass(SerializableDataclass): a: int b: str ``` ```python >>> Myclass(a=1, b="q").serialize() {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'} ``` # Parameters: - `_cls : _type_` class to decorate. don't pass this arg, just use this as a decorator (defaults to `None`) - `init : bool` whether to add an `__init__` method *(passed to dataclasses.dataclass)* (defaults to `True`) - `repr : bool` whether to add a `__repr__` method *(passed to dataclasses.dataclass)* (defaults to `True`) - `order : bool` whether to add rich comparison methods *(passed to dataclasses.dataclass)* (defaults to `False`) - `unsafe_hash : bool` whether to add a `__hash__` method *(passed to dataclasses.dataclass)* (defaults to `False`) - `frozen : bool` whether to make the class frozen *(passed to dataclasses.dataclass)* (defaults to `False`) - `properties_to_serialize : Optional[list[str]]` which properties to add to the serialized data dict **SerializableDataclass only** (defaults to `None`) - `register_handler : bool` if true, register the class with ZANJ for loading **SerializableDataclass only** (defaults to `True`) - `on_typecheck_error : ErrorMode` what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false **SerializableDataclass only** - `on_typecheck_mismatch : ErrorMode` what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True` **SerializableDataclass only** - `methods_no_override : list[str]|None` list of methods that should not be overridden by the decorator by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function, but you can disable this if you'd rather write your own. 
`dataclasses.dataclass` might still overwrite these, and those options take precedence **SerializableDataclass only** (defaults to `None`) - `**kwargs` *(passed to dataclasses.dataclass)* # Returns: - `_type_` the decorated class # Raises: - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this - `NotSerializableFieldException` : if a field is not a `SerializableField` - `FieldSerializationError` : if there is an error serializing a field - `AttributeError` : if a property is not found on the class - `FieldLoadingError` : if there is an error loading a field """ # -> Union[Callable[[Type[T]], Type[T]], Type[T]]: on_typecheck_error = ErrorMode.from_any(on_typecheck_error) on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch) if properties_to_serialize is None: _properties_to_serialize: list = list() else: _properties_to_serialize = properties_to_serialize def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]: # Modify the __annotations__ dictionary to replace regular fields with SerializableField for field_name, field_type in cls.__annotations__.items(): field_value = getattr(cls, field_name, None) if not isinstance(field_value, SerializableField): if isinstance(field_value, dataclasses.Field): # Convert the field to a SerializableField while preserving properties field_value = SerializableField.from_Field(field_value) else: # Create a new SerializableField field_value = serializable_field() setattr(cls, field_name, field_value) # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy if sys.version_info < (3, 10): if "kw_only" in kwargs: if kwargs["kw_only"] == True: # noqa: E712 raise KWOnlyError( "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored" ) else: del kwargs["kw_only"] # call `dataclasses.dataclass` to set some stuff up cls = dataclasses.dataclass( # type: ignore[call-overload] cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen, **kwargs, ) # copy these to the class cls._properties_to_serialize = _properties_to_serialize.copy() # type: ignore[attr-defined] # ====================================================================== # define `serialize` func # done locally since it depends on args to the decorator # ====================================================================== def serialize(self: Any) -> dict[str, Any]: result: dict[str, Any] = { _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" } # for each field in the class for field in dataclasses.fields(self): # type: ignore[arg-type] # need it to be our special SerializableField if not isinstance(field, SerializableField): raise NotSerializableFieldException( f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " f"but a {type(field)} " "this state should be inaccessible, please report this bug!" 
) # try to save it if field.serialize: value: Any = None # init before try in case getattr raises try: # get the val value = getattr(self, field.name) # if it is a serializable dataclass, serialize it if isinstance(value, SerializableDataclass): value = value.serialize() # if the value has a serialization function, use that if hasattr(value, "serialize") and callable(value.serialize): # pyright: ignore[reportAttributeAccessIssue] value = value.serialize() # pyright: ignore[reportAttributeAccessIssue] # if the field has a serialization function, use that # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! elif field.serialization_fn: value = field.serialization_fn(value) # store the value in the result result[field.name] = value except Exception as e: raise FieldSerializationError( "\n".join( [ f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", f"{field = }", f"{value or '' = }", f"{self = }", ] ) ) from e # store each property if we can get it for prop in self._properties_to_serialize: if hasattr(cls, prop): value = getattr(self, prop) result[prop] = value else: raise AttributeError( f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" + f"but it is in {self._properties_to_serialize = }" + f"\n{self = }" ) return result # ====================================================================== # define `load` func # done locally since it depends on args to the decorator # ====================================================================== # mypy thinks this isnt a classmethod @classmethod # type: ignore[misc] def load( cls: type[T_SerializeableDataclass], data: dict[str, Any] | T_SerializeableDataclass, ) -> T_SerializeableDataclass: # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ if isinstance(data, cls): return data assert isinstance(data, typing.Mapping), ( f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" ) cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) # initialize dict for keeping what we will pass to the constructor ctor_kwargs: dict[str, Any] = dict() # iterate over the fields of the class # mypy doesn't recognize @dataclass_transform for dataclasses.fields() # https://github.com/python/mypy/issues/16241 for field in dataclasses.fields(cls): # type: ignore[arg-type] # check if the field is a SerializableField assert isinstance(field, SerializableField), ( f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. 
this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" ) # check if the field is in the data and if it should be initialized if (field.name in data) and field.init: # get the value, we will be processing it value: Any = data[field.name] # get the type hint for the field field_type_hint: Any = cls_type_hints.get(field.name, None) # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set if field.deserialize_fn: # if it has a deserialization function, use that value = field.deserialize_fn(value) elif field.loading_fn: # if it has a loading function, use that value = field.loading_fn(data) elif ( field_type_hint is not None and hasattr(field_type_hint, "load") and callable(field_type_hint.load) ): # if no loading function but has a type hint with a load method, use that if isinstance(value, dict): value = field_type_hint.load(value) else: raise FieldLoadingError( f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" ) else: # assume no loading needs to happen, keep `value` as-is pass # store the value in the constructor kwargs ctor_kwargs[field.name] = value # create a new instance of the class with the constructor kwargs output: T_SerializeableDataclass = cls(**ctor_kwargs) # validate the types of the fields if needed if on_typecheck_mismatch != ErrorMode.IGNORE: fields_valid: dict[str, bool] = ( SerializableDataclass__validate_fields_types__dict( output, on_typecheck_error=on_typecheck_error, ) ) # if there are any fields that are not valid, raise an error if not all(fields_valid.values()): msg: str = ( f"Type mismatch in fields of {cls.__name__}:\n" + "\n".join( [ f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" for k, v in fields_valid.items() if not v ] ) ) on_typecheck_mismatch.process( msg, except_cls=FieldTypeMismatchError ) # return the new instance return output _methods_no_override: set[str] if methods_no_override is None: _methods_no_override = set() else: _methods_no_override = set(methods_no_override) if _methods_no_override - { "__eq__", "serialize", "load", "validate_fields_types", }: warnings.warn( f"Unknown methods in `methods_no_override`: {_methods_no_override = }" ) # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments if "serialize" not in _methods_no_override: # type is `Callable[[T], dict]` cls.serialize = serialize # type: ignore[attr-defined, method-assign] if "load" not in _methods_no_override: # type is `Callable[[dict], T]` cls.load = load # type: ignore[attr-defined, method-assign, assignment] if "validate_field_type" not in _methods_no_override: # type is `Callable[[T, ErrorMode], bool]` cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined, method-assign] if "__eq__" not in _methods_no_override: # type is `Callable[[T, T], bool]` cls.__eq__ = lambda self, other: dc_eq(self, other) # type: ignore[assignment] # Register the class with ZANJ if register_handler: zanj_register_loader_serializable_dataclass(cls) return cls if _cls is None: return wrap else: return wrap(_cls) ``````{ end_of_file="muutils/json_serialize/serializable_dataclass.py" } ``````{ path="muutils/json_serialize/serializable_field.py" } """extends `dataclasses.Field` for use with `SerializableDataclass` In particular, instead of using `dataclasses.field`, use 
`serializable_field` to define fields in a `SerializableDataclass`. You provide information on how the field should be serialized and loaded (as well as anything that goes into `dataclasses.field`) when you define the field, and the `SerializableDataclass` will automatically use those functions. """ from __future__ import annotations import dataclasses import sys import types from typing import Any, Callable, Optional, Union, overload, TypeVar # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access class SerializableField(dataclasses.Field): """extension of `dataclasses.Field` with additional serialization properties""" __slots__ = ( # from dataclasses.Field.__slots__ "name", "type", "default", "default_factory", "repr", "hash", "init", "compare", "doc", "metadata", "kw_only", "_field_type", # Private: not to be used by user code. # new ones "serialize", "serialization_fn", "loading_fn", "deserialize_fn", # new alternative to loading_fn "assert_type", "custom_typecheck_fn", ) def __init__( self, default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, default_factory: Union[ Callable[[], Any], dataclasses._MISSING_TYPE ] = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, doc: str | None = None, # TODO: add field for custom comparator (such as serializing) metadata: Optional[types.MappingProxyType] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, loading_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, ): # TODO: should we do this check, or assume the user knows what they are doing? if init and not serialize: raise ValueError("Cannot have init=True and serialize=False") # need to assemble kwargs in this hacky way so as not to upset type checking super_kwargs: dict[str, Any] = dict( default=default, default_factory=default_factory, init=init, repr=repr, hash=hash, compare=compare, kw_only=kw_only, ) if metadata is not None: super_kwargs["metadata"] = metadata else: super_kwargs["metadata"] = types.MappingProxyType({}) # only pass `doc` to super if python >=3.14 if sys.version_info >= (3, 14): super_kwargs["doc"] = doc # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy if sys.version_info < (3, 10): if super_kwargs["kw_only"] == True: # noqa: E712 raise ValueError("kw_only is not supported in python >=3.9") else: del super_kwargs["kw_only"] # actually init the super class super().__init__(**super_kwargs) # type: ignore[call-arg] # init doc if python <3.14 if sys.version_info < (3, 14): self.doc: str | None = doc # now init the new fields self.serialize: bool = serialize self.serialization_fn: Optional[Callable[[Any], Any]] = serialization_fn if loading_fn is not None and deserialize_fn is not None: raise ValueError( "Cannot pass both loading_fn and deserialize_fn, pass only one. 
", "`loading_fn` is the older interface and takes the dict of the class, ", "`deserialize_fn` is the new interface and takes only the field's value.", ) self.loading_fn: Optional[Callable[[Any], Any]] = loading_fn self.deserialize_fn: Optional[Callable[[Any], Any]] = deserialize_fn self.assert_type: bool = assert_type self.custom_typecheck_fn: Optional[Callable[[type], bool]] = custom_typecheck_fn @classmethod def from_Field(cls, field: "dataclasses.Field[Any]") -> "SerializableField": """copy all values from a `dataclasses.Field` to new `SerializableField`""" return cls( default=field.default, default_factory=field.default_factory, init=field.init, repr=field.repr, hash=field.hash, compare=field.compare, doc=getattr(field, "doc", None), # `doc` added in python <3.14 metadata=field.metadata, kw_only=getattr(field, "kw_only", dataclasses.MISSING), # for python <3.9 serialize=field.repr, # serialize if it's going to be repr'd serialization_fn=None, loading_fn=None, deserialize_fn=None, ) Sfield_T = TypeVar("Sfield_T") @overload def serializable_field( # only `default_factory` is provided *_args: Any, default_factory: Callable[[], Sfield_T], default: dataclasses._MISSING_TYPE = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, doc: str | None = None, metadata: Optional[types.MappingProxyType] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, **kwargs: Any, ) -> Sfield_T: ... @overload def serializable_field( # only `default` is provided *_args: Any, default: Sfield_T, default_factory: dataclasses._MISSING_TYPE = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, doc: str | None = None, metadata: Optional[types.MappingProxyType] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, **kwargs: Any, ) -> Sfield_T: ... @overload def serializable_field( # both `default` and `default_factory` are MISSING *_args: Any, default: dataclasses._MISSING_TYPE = dataclasses.MISSING, default_factory: dataclasses._MISSING_TYPE = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, doc: str | None = None, metadata: Optional[types.MappingProxyType] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, **kwargs: Any, ) -> Any: ... 
def serializable_field( # general implementation *_args: Any, default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, doc: str | None = None, metadata: Optional[types.MappingProxyType] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, **kwargs: Any, ) -> Any: """Create a new `SerializableField` ``` default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING, default_factory: Callable[[], Sfield_T] | dataclasses._MISSING_TYPE = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, doc: str | None = None, # new in python 3.14. can alternately pass `description` to match pydantic, but this is discouraged metadata: types.MappingProxyType | None = None, kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING, # ---------------------------------------------------------------------- # new in `SerializableField`, not in `dataclasses.Field` serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, loading_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, ``` # new Parameters: - `serialize`: whether to serialize this field when serializing the class' - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize` - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is. - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised. - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field. - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking. # Gotchas: - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write: ```python class MyClass: my_field: int = serializable_field( serialization_fn=lambda x: str(x), loading_fn=lambda x["my_field"]: int(x) ) ``` using `deserialize_fn` instead: ```python class MyClass: my_field: int = serializable_field( serialization_fn=lambda x: str(x), deserialize_fn=lambda x: int(x) ) ``` In the above code, `my_field` is an int but will be serialized as a string. note that if not using ZANJ, and you have a class inside a container, you MUST provide `serialization_fn` and `loading_fn` to serialize and load the container. ZANJ will automatically do this for you. # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. 
if not provided, any value is valid as long as it passes the type test """ assert len(_args) == 0, f"unexpected positional arguments: {_args}" if "description" in kwargs: import warnings warnings.warn( "`description` is deprecated, use `doc` instead", DeprecationWarning, ) if doc is not None: err_msg: str = f"cannot pass both `doc` and `description`: {doc=}, {kwargs['description']=}" raise ValueError(err_msg) doc = kwargs.pop("description") return SerializableField( default=default, default_factory=default_factory, init=init, repr=repr, hash=hash, compare=compare, doc=doc, metadata=metadata, kw_only=kw_only, serialize=serialize, serialization_fn=serialization_fn, deserialize_fn=deserialize_fn, assert_type=assert_type, custom_typecheck_fn=custom_typecheck_fn, **kwargs, ) ``````{ end_of_file="muutils/json_serialize/serializable_field.py" } ``````{ path="muutils/json_serialize/types.py" } """base types, lets us avoid import cycles""" from __future__ import annotations from typing import TYPE_CHECKING, List, Literal, Union, Tuple if TYPE_CHECKING: from muutils.json_serialize.util import JSONitem BaseType = Union[ bool, int, float, str, None, ] Hashableitem = Union[BaseType, Tuple["Hashableitem", ...]] _FORMAT_KEY: Literal["__muutils_format__"] = "__muutils_format__" _REF_KEY: Literal["$ref"] = "$ref" # TypedDicts for serialized set/frozenset - using Total=False workaround for 3.8 compat # These are used by @overload signatures for better type narrowing try: from typing import TypedDict except ImportError: from typing_extensions import TypedDict class _SerializedSet(TypedDict): """TypedDict for serialized set objects.""" __muutils_format__: Literal["set"] data: List["JSONitem"] class _SerializedFrozenset(TypedDict): """TypedDict for serialized frozenset objects.""" __muutils_format__: Literal["frozenset"] data: List["JSONitem"] ``````{ end_of_file="muutils/json_serialize/types.py" } ``````{ path="muutils/json_serialize/util.py" } """utilities for json_serialize""" from __future__ import annotations import dataclasses import functools import inspect import sys import typing import warnings from typing import Any, Callable, Iterable, TypeVar, Union from muutils.json_serialize.types import BaseType, Hashableitem if typing.TYPE_CHECKING: pass _NUMPY_WORKING: bool try: _NUMPY_WORKING = True except ImportError: warnings.warn("numpy not found, cannot serialize numpy arrays!") _NUMPY_WORKING = False # pyright: reportExplicitAny=false # At type-checking time, include array serialization types to avoid nominal type errors # This avoids superfluous imports at runtime # if TYPE_CHECKING: # from muutils.json_serialize.array import NumericList, SerializedArrayWithMeta # JSONitem = Union[ # BaseType, # typing.Sequence["JSONitem"], # typing.Dict[str, "JSONitem"], # SerializedArrayWithMeta, # NumericList, # ] # else: JSONitem = Union[ BaseType, typing.Sequence["JSONitem"], typing.Dict[str, "JSONitem"], # TODO: figure this out # "_SerializedSet", # "_SerializedFrozenset", ] JSONdict = typing.Dict[str, JSONitem] # TODO: this bit is very broken # or if python version <3.9 if typing.TYPE_CHECKING or sys.version_info < (3, 9): MonoTuple = typing.Sequence else: class MonoTuple: # pyright: ignore[reportUnreachable] """tuple type hint, but for a tuple of any length with all the same type""" __slots__ = () def __new__(cls, *args, **kwargs): raise TypeError("Type MonoTuple cannot be instantiated.") def __init_subclass__(cls, *args, **kwargs): raise TypeError(f"Cannot subclass {cls.__module__}") # idk why mypy thinks 
there is no such function in typing @typing._tp_cache # type: ignore def __class_getitem__(cls, params): if getattr(params, "__origin__", None) == typing.Union: return typing.GenericAlias(tuple, (params, Ellipsis)) elif isinstance(params, type): typing.GenericAlias(tuple, (params, Ellipsis)) # test if has len and is iterable elif isinstance(params, Iterable): if len(params) == 0: return tuple elif len(params) == 1: return typing.GenericAlias(tuple, (params[0], Ellipsis)) else: raise TypeError(f"MonoTuple expects 1 type argument, got {params = }") # TYPING: we allow `Any` here because the container is... universal class UniversalContainer: """contains everything -- `x in UniversalContainer()` is always True""" def __contains__(self, x: Any) -> bool: # pyright: ignore[reportAny] return True def isinstance_namedtuple(x: Any) -> bool: # pyright: ignore[reportAny] """checks if `x` is a `namedtuple` credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple """ t: type = type(x) # pyright: ignore[reportUnknownVariableType, reportAny] b: tuple[type, ...] = t.__bases__ if len(b) != 1 or (b[0] is not tuple): return False f: Any = getattr(t, "_fields", None) if not isinstance(f, tuple): return False # fine that the type is unknown -- that's what we want to check return all(isinstance(n, str) for n in f) # pyright: ignore[reportUnknownVariableType] T_FuncTryCatchReturn = TypeVar("T_FuncTryCatchReturn") def try_catch( func: Callable[..., T_FuncTryCatchReturn], ) -> Callable[..., Union[T_FuncTryCatchReturn, str]]: """wraps the function to catch exceptions, returns serialized error message on exception returned func will return normal result on success, or error message on exception """ @functools.wraps(func) def newfunc(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]: # pyright: ignore[reportAny] try: return func(*args, **kwargs) except Exception as e: return f"{e.__class__.__name__}: {e}" return newfunc # TYPING: can we get rid of any of these? 
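# usage sketch for `try_catch` (the `_safe_int` name below is illustrative only):
# the wrapped callable returns its normal result on success, and the string
# "ExceptionName: message" instead of raising on failure, which keeps
# bulk-serialization loops alive when a single value is bad.
_safe_int: Callable[..., Union[int, str]] = try_catch(int)
# _safe_int("3")    -> 3
# _safe_int("oops") -> "ValueError: invalid literal for int() with base 10: 'oops'"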
def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: # pyright: ignore[reportAny] if isinstance(obj, typing.Mapping): return tuple((k, _recursive_hashify(v)) for k, v in obj.items()) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] elif isinstance(obj, (bool, int, float, str)): return obj elif isinstance(obj, (tuple, list, Iterable)): return tuple(_recursive_hashify(v) for v in obj) # pyright: ignore[reportUnknownVariableType] else: if force: return str(obj) # pyright: ignore[reportAny] else: raise ValueError(f"cannot hashify:\n{obj}") class SerializationException(Exception): pass def string_as_lines(s: str | None) -> list[str]: """for easier reading of long strings in json, split up by newlines sort of like how jupyter notebooks do it """ if s is None: return list() else: return s.splitlines(keepends=False) def safe_getsource(func: Callable[..., Any]) -> list[str]: try: return string_as_lines(inspect.getsource(func)) except Exception as e: return string_as_lines(f"Error: Unable to retrieve source code:\n{e}") # credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises def array_safe_eq(a: Any, b: Any) -> bool: # pyright: ignore[reportAny] """check if two objects are equal, account for if numpy arrays or torch tensors""" if a is b: return True if type(a) is not type(b): # pyright: ignore[reportAny] return False if ( str(type(a)) == "" # pyright: ignore[reportAny, reportUnknownArgumentType] and str(type(b)) == "" # pyright: ignore[reportAny, reportUnknownArgumentType] ) or ( str(type(a)) == "" # pyright: ignore[reportAny, reportUnknownArgumentType] and str(type(b)) == "" # pyright: ignore[reportAny, reportUnknownArgumentType] ): return (a == b).all() # pyright: ignore[reportAny] if ( str(type(a)) == "" # pyright: ignore[reportUnknownArgumentType, reportAny] and str(type(b)) == "" # pyright: ignore[reportUnknownArgumentType, reportAny] ): return a.equals(b) # pyright: ignore[reportAny] if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence): if len(a) == 0 and len(b) == 0: return True return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b)) if isinstance(a, (dict, typing.Mapping)) and isinstance(b, (dict, typing.Mapping)): return len(a) == len(b) and all( # pyright: ignore[reportUnknownArgumentType] array_safe_eq(k1, k2) and array_safe_eq(a[k1], b[k2]) for k1, k2 in zip(a.keys(), b.keys()) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] ) try: return bool(a == b) # pyright: ignore[reportAny] except (TypeError, ValueError) as e: warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}") return NotImplemented # type: ignore[return-value] # TYPING: see what can be done about so many `Any`s here def dc_eq( dc1: Any, # pyright: ignore[reportAny] dc2: Any, # pyright: ignore[reportAny] except_when_class_mismatch: bool = False, false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False, ) -> bool: """ checks if two dataclasses which (might) hold numpy arrays are equal # Parameters: - `dc1`: the first dataclass - `dc2`: the second dataclass - `except_when_class_mismatch: bool` if `True`, will throw `TypeError` if the classes are different. if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False` (default: `False`) - `false_when_class_mismatch: bool` only relevant if `except_when_class_mismatch` is `False`. 
if `True`, will return `False` if the classes are different. if `False`, will attempt to compare the fields. - `except_when_field_mismatch: bool` only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`. if `True`, will throw `AttributeError` if the fields are different. (default: `False`) # Returns: - `bool`: True if the dataclasses are equal, False otherwise # Raises: - `TypeError`: if the dataclasses are of different classes - `AttributeError`: if the dataclasses have different fields ``` [START] ▼ ┌─────────────┐ │ dc1 is dc2? │───Yes───► (True) └──────┬──────┘ │No ▼ ┌───────────────┐ │ classes match?│───Yes───► [compare field values] ───► (True/False) └──────┬────────┘ │No ▼ ┌────────────────────────────┐ │ except_when_class_mismatch?│───Yes───► { raise TypeError } └─────────────┬──────────────┘ │No ▼ ┌────────────────────────────┐ │ false_when_class_mismatch? │───Yes───► (False) └─────────────┬──────────────┘ │No ▼ ┌────────────────────────────┐ │ except_when_field_mismatch?│───No────► [compare field values] └─────────────┬──────────────┘ │Yes ▼ ┌───────────────┐ │ fields match? │───Yes───► [compare field values] └──────┬────────┘ │No ▼ { raise AttributeError } ``` """ if dc1 is dc2: return True if dc1.__class__ is not dc2.__class__: # pyright: ignore[reportAny] if except_when_class_mismatch: # if the classes don't match, raise an error raise TypeError( f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`" # pyright: ignore[reportAny] ) if false_when_class_mismatch: # return False immediately without attempting field comparison return False # classes don't match but we'll try to compare fields anyway if except_when_field_mismatch: dc1_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc1)]) # pyright: ignore[reportAny] dc2_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc2)]) # pyright: ignore[reportAny] fields_match: bool = set(dc1_fields) == set(dc2_fields) if not fields_match: # if the fields don't match, raise an error raise AttributeError( f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`" ) return all( array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)) # pyright: ignore[reportAny] for fld in dataclasses.fields(dc1) # pyright: ignore[reportAny] if fld.compare ) ``````{ end_of_file="muutils/json_serialize/util.py" } ``````{ path="muutils/logger/__init__.py" } """(deprecated) experimenting with logging utilities""" import warnings from muutils.logger.logger import Logger from muutils.logger.loggingstream import LoggingStream from muutils.logger.simplelogger import SimpleLogger from muutils.logger.timing import TimerContext warnings.warn( DeprecationWarning( "muutils.logger is no longer maintained. Consider using [trnbl](https://github.com/mivanit/trnbl) instead." ) ) __all__ = [ # submodules "exception_context", "headerfuncs", "log_util", "logger", "loggingstream", "simplelogger", "timing", # imports "Logger", "LoggingStream", "SimpleLogger", "TimerContext", ] ``````{ end_of_file="muutils/logger/__init__.py" } ``````{ path="muutils/logger/exception_context.py" } from __future__ import annotations import json from types import TracebackType from typing import Protocol from muutils.json_serialize import json_serialize class WritableStream(Protocol): """Protocol for objects that support write operations.""" def write(self, __s: str) -> int: ... 
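# sketch: any object with a `write(str) -> int` method satisfies `WritableStream`,
# so an in-memory buffer like `io.StringIO` can collect the serialized exception
# info from the context manager defined below (illustrative usage only):
#
#     import io
#     buf = io.StringIO()
#     try:
#         with ExceptionContext(buf):
#             raise RuntimeError("boom")
#     except RuntimeError:
#         pass  # the exception is still re-raised after being recorded
#     print(buf.getvalue())  # JSON with exc_type / exc_value / exc_traceback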
class ExceptionContext: """context manager which catches all exceptions happening while the context is open, `.write()` the exception trace to the given stream, and then raises the exception for example: ```python errorfile = open('error.log', 'w') with ExceptionContext(errorfile): # do something that might throw an exception # if it does, the exception trace will be written to errorfile # and then the exception will be raised ``` """ def __init__(self, stream: WritableStream) -> None: self.stream: WritableStream = stream def __enter__(self) -> ExceptionContext: return self def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_traceback: TracebackType | None, ) -> bool: if exc_type is not None: self.stream.write( json.dumps( json_serialize( { "exc_type": exc_type, "exc_value": exc_value, "exc_traceback": exc_traceback, } ) ) ) return False return True ``````{ end_of_file="muutils/logger/exception_context.py" } ``````{ path="muutils/logger/headerfuncs.py" } from __future__ import annotations import json from typing import Any, Mapping, Protocol from muutils.json_serialize import json_serialize # takes message, level, other data, and outputs message with appropriate header # HeaderFunction = Callable[[str, int, Any], str] class HeaderFunction(Protocol): def __call__(self, msg: Any, lvl: int, **kwargs: Any) -> str: ... def md_header_function( msg: Any, lvl: int, stream: str | None = None, indent_lvl: str = " ", extra_indent: str = "", **kwargs: Any, ) -> str: """standard header function. will output - `# {msg}` for levels in [0, 9] - `## {msg}` for levels in [10, 19], and so on - `[{stream}] # {msg}` for a non-`None` stream, with level headers as before - `!WARNING! [{stream}] {msg}` for level in [-9, -1] - `!!WARNING!! [{stream}] {msg}` for level in [-19, -10] and so on """ stream_prefix: str = "" if stream is not None: stream_prefix = f"[{stream}] " lvl_div_10: int = lvl // 10 msg_processed: str if isinstance(msg, Mapping): msg_processed = ", ".join([f"{k}: {json_serialize(v)}" for k, v in msg.items()]) else: msg_processed = json.dumps(json_serialize(msg)) if lvl >= 0: return f"{extra_indent}{indent_lvl * (lvl_div_10 - 1)}{stream_prefix}#{'#' * lvl_div_10 if lvl else ''} {msg_processed}" else: exclamation_pts: str = "!" * (abs(lvl) // 10) return f"{extra_indent}{exclamation_pts}WARNING{exclamation_pts} {stream_prefix} {msg_processed}" HEADER_FUNCTIONS: dict[str, HeaderFunction] = { "md": md_header_function, } ``````{ end_of_file="muutils/logger/headerfuncs.py" } ``````{ path="muutils/logger/log_util.py" } from __future__ import annotations from typing import Any, TypeVar from muutils.jsonlines import jsonl_load_log T_StreamValue = TypeVar("T_StreamValue") def get_any_from_stream( stream: list[dict[str, T_StreamValue]], key: str ) -> T_StreamValue: """get the first value of a key from a stream. 
errors if not found""" for msg in stream: if key in msg: return msg[key] raise KeyError(f"key '{key}' not found in stream") def gather_log(file: str) -> dict[str, list[dict[str, Any]]]: """gathers and sorts all streams from a log""" data: list[dict[str, Any]] = jsonl_load_log(file) output: dict[str, list[dict[str, Any]]] = dict() for item in data: stream: str = item.get("_stream", "default") if stream not in output: output[stream] = list() output[stream].append(item) return output def gather_stream( file: str, stream: str, ) -> list[dict[str, Any]]: """gets all entries from a specific stream in a log file""" data: list[dict[str, Any]] = jsonl_load_log(file) output: list[dict[str, Any]] = list() for item in data: # select for the stream if ("_stream" in item) and (item["_stream"] == stream): output.append(item) return output def gather_val( file: str, stream: str, keys: tuple[str, ...], allow_skip: bool = True, ) -> list[list[Any]]: """gather specific keys from a specific stream in a log file example: if "log.jsonl" has contents: ```jsonl {"a": 1, "b": 2, "c": 3, "_stream": "s1"} {"a": 4, "b": 5, "c": 6, "_stream": "s1"} {"a": 7, "b": 8, "c": 9, "_stream": "s2"} ``` then `gather_val("log.jsonl", "s1", ("a", "b"))` will return ```python [ [1, 2], [4, 5] ] ``` """ data: list[dict[str, Any]] = jsonl_load_log(file) output: list[list[Any]] = list() for item in data: # select for the stream if ("_stream" in item) and (item["_stream"] == stream): # select for the keys if all(k in item for k in keys): output.append(list(item[k] for k in keys)) elif not allow_skip: raise ValueError(f"missing keys '{keys = }' in '{item = }'") return output ``````{ end_of_file="muutils/logger/log_util.py" } ``````{ path="muutils/logger/logger.py" } """logger with streams & levels, and a timer context manager - `SimpleLogger` is an extremely simple logger that can write to both console and a file - `Logger` class handles levels in a slightly different way than default python `logging`, and also has "streams" which allow for different sorts of output in the same logger this was mostly made with training models in mind and storing both metadata and loss - `TimerContext` is a context manager that can be used to time the duration of a block of code """ from __future__ import annotations import json import time import typing from functools import partial from typing import Any, Callable, Sequence from muutils.json_serialize import JSONitem, json_serialize from muutils.logger.exception_context import ExceptionContext from muutils.logger.headerfuncs import HEADER_FUNCTIONS, HeaderFunction from muutils.logger.loggingstream import LoggingStream from muutils.logger.simplelogger import AnyIO, SimpleLogger # pylint: disable=arguments-differ, bad-indentation, trailing-whitespace, trailing-newlines, unnecessary-pass, consider-using-with, use-dict-literal def decode_level(level: int) -> str: if not isinstance(level, int): raise TypeError(f"level must be int, got {type(level) = } {level = }") if level < -255: return f"FATAL_ERROR({level})" elif level < 0: return f"WARNING({level})" else: return f"INFO({level})" # todo: add a context which catches and logs all exceptions class Logger(SimpleLogger): """logger with more features, including log levels and streams # Parameters: - `log_path : str | None` default log file path (defaults to `None`) - `log_file : AnyIO | None` default log io, should have a `.write()` method (pass only this or `log_path`, not both) (defaults to `None`) - `timestamp : bool` whether to add timestamps to every log 
message (under the `_timestamp` key) (defaults to `True`) - `default_level : int` default log level for streams/messages that don't specify a level (defaults to `0`) - `console_print_threshold : int` log level at which to print to the console, anything greater will not be printed unless overridden by `console_print` (defaults to `50`) - `level_header : HeaderFunction` function for formatting log messages when printing to console (defaults to `HEADER_FUNCTIONS["md"]`) - `keep_last_msg_time : bool` whether to keep the last message time (defaults to `True`) # Raises: - `ValueError` : _description_ """ def __init__( self, log_path: str | None = None, log_file: AnyIO | None = None, default_level: int = 0, console_print_threshold: int = 50, level_header: HeaderFunction = HEADER_FUNCTIONS["md"], streams: dict[str | None, LoggingStream] | Sequence[LoggingStream] = (), keep_last_msg_time: bool = True, # junk args timestamp: bool = True, **kwargs: Any, ) -> None: # junk arg checking # ================================================== if len(kwargs) > 0: raise ValueError(f"unrecognized kwargs: {kwargs}") if not timestamp: raise ValueError( "timestamp must be True -- why would you not want timestamps?" ) # timing # ================================================== # timing compares self._keep_last_msg_time: bool = keep_last_msg_time # TODO: handle per stream? self._last_msg_time: float | None = time.time() # basic setup # ================================================== # init BaseLogger super().__init__(log_file=log_file, log_path=log_path, timestamp=timestamp) # level-related self._console_print_threshold: int = console_print_threshold self._default_level: int = default_level # set up streams self._streams: dict[str | None, LoggingStream] = ( streams if isinstance(streams, typing.Mapping) else {s.name: s for s in streams} ) # default error stream if "error" not in self._streams: self._streams["error"] = LoggingStream( "error", aliases={ "err", "except", "Exception", "exception", "exceptions", "errors", }, ) # check alias duplicates alias_set: set[str | None] = set() for stream in self._streams.values(): for alias in stream.aliases: if alias in alias_set: raise ValueError(f"alias {alias} is already in use") alias_set.add(alias) # add aliases for stream in tuple(self._streams.values()): for alias in stream.aliases: if alias not in self._streams: self._streams[alias] = stream # print formatting self._level_header: HeaderFunction = level_header print({k: str(v) for k, v in self._streams.items()}) def _exception_context( self, stream: str = "error", # level: int = -256, # **kwargs, ) -> ExceptionContext: import sys s: LoggingStream = self._streams[stream] handler = s.handler if s.handler is not None else sys.stderr return ExceptionContext(stream=handler) def log( self, msg: JSONitem = None, *, lvl: int | None = None, stream: str | None = None, console_print: bool = False, extra_indent: str = "", **kwargs: Any, ) -> None: """logging function ### Parameters: - `msg : JSONitem` message (usually string or dict) to be logged - `lvl : int | None` level of message (lower levels are more important) (defaults to `None`) - `console_print : bool` override `console_print_threshold` setting (defaults to `False`) - `stream : str | None` whether to log to a stream (defaults to `None`), which logs to the default `None` stream (defaults to `None`) """ # add to known stream names if not present if stream not in self._streams: self._streams[stream] = LoggingStream(stream) # set default level to either global or 
stream-specific default level # ======================================== if lvl is None: if stream is None: lvl = self._default_level else: if self._streams[stream].default_level is not None: lvl = self._streams[stream].default_level else: lvl = self._default_level assert lvl is not None, "lvl should not be None at this point" # print to console with formatting # ======================================== _printed: bool = False if console_print or (lvl <= self._console_print_threshold): # add some formatting print( self._level_header( msg=msg, lvl=lvl, stream=stream, extra_indent=extra_indent, ) ) # store the last message time if self._last_msg_time is not None: self._last_msg_time = time.time() _printed = True # convert and add data # ======================================== # converting to dict msg_dict: dict[str, Any] if not isinstance(msg, typing.Mapping): msg_dict = {"_msg": msg} else: msg_dict = dict(typing.cast(typing.Mapping[str, Any], msg)) # level+stream metadata if lvl is not None: msg_dict["_lvl"] = lvl # msg_dict["_stream"] = stream # moved to LoggingStream # extra data in kwargs if len(kwargs) > 0: msg_dict["_kwargs"] = kwargs # add default contents (timing, etc) msg_dict = { **{k: v() for k, v in self._streams[stream].default_contents.items()}, **msg_dict, } # write # ======================================== logfile_msg: str = json.dumps(json_serialize(msg_dict)) + "\n" if ( (stream is None) or (stream not in self._streams) or (self._streams[stream].handler is None) ): # write to the main log file if no stream is specified self._log_file_handle.write(logfile_msg) else: # otherwise, write to the stream-specific file s_handler: AnyIO | None = self._streams[stream].handler if s_handler is not None: s_handler.write(logfile_msg) else: raise ValueError( f"stream handler is None! 
something in the logging stream setup is wrong:\n{self}" ) # if it was important enough to print, flush all streams if _printed: self.flush_all() def log_elapsed_last( self, lvl: int | None = None, stream: str | None = None, console_print: bool = True, **kwargs: Any, ) -> None: """logs the time elapsed since the last message was printed to the console (in any stream)""" if self._last_msg_time is None: raise ValueError("no last message time!") else: self.log( {"elapsed_time": round(time.time() - self._last_msg_time, 6)}, lvl=(lvl if lvl is not None else self._console_print_threshold), stream=stream, console_print=console_print, **kwargs, ) def flush_all(self): """flush all streams""" self._log_file_handle.flush() for stream in self._streams.values(): if stream.handler is not None: stream.handler.flush() def __getattr__(self, stream: str) -> Callable[..., Any]: if stream.startswith("_"): raise AttributeError(f"invalid stream name {stream} (no underscores)") return partial(self.log, stream=stream) def __getitem__(self, stream: str) -> Callable[..., Any]: return partial(self.log, stream=stream) def __call__(self, *args: Any, **kwargs: Any) -> None: self.log(*args, **kwargs) ``````{ end_of_file="muutils/logger/logger.py" } ``````{ path="muutils/logger/loggingstream.py" } from __future__ import annotations import sys import time from dataclasses import dataclass, field from typing import Any, Callable if sys.version_info >= (3, 12): from typing import override else: from typing_extensions import override from muutils.logger.simplelogger import AnyIO, NullIO from muutils.misc import sanitize_fname @dataclass class LoggingStream: """properties of a logging stream - `name: str` name of the stream - `aliases: set[str]` aliases for the stream (calls to these names will be redirected to this stream. duplicate alises will result in errors) TODO: perhaps duplicate alises should result in duplicate writes? - `file: str|bool|AnyIO|None` file to write to - if `None`, will write to standard log - if `True`, will write to `name + ".log"` - if `False` will "write" to `NullIO` (throw it away) - if a string, will write to that file - if a fileIO type object, will write to that object - `default_level: int|None` default level for this stream - `default_contents: dict[str, Callable[[], Any]]` default contents for this stream - `last_msg: tuple[float, Any]|None` last message written to this stream (timestamp, message) """ name: str | None aliases: set[str | None] = field(default_factory=set) file: str | bool | AnyIO | None = None default_level: int | None = None default_contents: dict[str, Callable[[], Any]] = field(default_factory=dict) handler: AnyIO | None = None # TODO: implement last-message caching # last_msg: tuple[float, Any]|None = None def make_handler(self) -> AnyIO | None: if self.file is None: return None elif isinstance(self.file, str): # if its a string, open a file return open( self.file, "w", encoding="utf-8", ) elif isinstance(self.file, bool): # if its a bool and true, open a file with the same name as the stream (in the current dir) # TODO: make this happen in the same dir as the main logfile? 
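            # `self.file is True`  -> open "<sanitized stream name>.log.jsonl" in the current directory
            # `self.file is False` -> return a `NullIO`, silently discarding everything written to this stream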
if self.file: return open( # type: ignore[return-value] f"{sanitize_fname(self.name)}.log.jsonl", "w", encoding="utf-8", ) else: return NullIO() else: # if its neither, check it has `.write()` and `.flush()` methods if ( ( not hasattr(self.file, "write") or (not callable(self.file.write)) or (not hasattr(self.file, "flush")) or (not callable(self.file.flush)) ) or (not hasattr(self.file, "close")) or (not callable(self.file.close)) ): raise ValueError(f"stream {self.name} has invalid handler {self.file}") # ignore type check because we know it has a .write() method, # assume the user knows what they're doing return self.file # type: ignore def __post_init__(self): self.aliases = set(self.aliases) if any(x.startswith("_") for x in self.aliases if x is not None): raise ValueError( "stream names or aliases cannot start with an underscore, sorry" ) self.aliases.add(self.name) self.default_contents["_timestamp"] = time.time self.default_contents["_stream"] = lambda: self.name self.handler = self.make_handler() def __del__(self): if self.handler is not None: self.handler.flush() self.handler.close() @override def __str__(self): return f"LoggingStream(name={self.name}, aliases={self.aliases}, file={self.file}, default_level={self.default_level}, default_contents={self.default_contents})" ``````{ end_of_file="muutils/logger/loggingstream.py" } ``````{ path="muutils/logger/simplelogger.py" } from __future__ import annotations import json import sys import time import typing from typing import Any, TextIO, Union from muutils.json_serialize import JSONitem, json_serialize class NullIO: """null IO class""" def __init__(self) -> None: pass def write(self, msg: str) -> int: """write to nothing! this throws away the message""" return len(msg) def flush(self) -> None: """flush nothing! this is a no-op""" pass def close(self) -> None: """close nothing! 
this is a no-op""" pass AnyIO = Union[TextIO, NullIO] class SimpleLogger: """logs training data to a jsonl file""" def __init__( self, log_path: str | None = None, log_file: AnyIO | None = None, timestamp: bool = True, ): self._timestamp: bool = timestamp self._log_path: str | None = log_path self._log_file_handle: AnyIO if (log_path is None) and (log_file is None): print( "[logger_internal] # no log file specified, will only write to console", sys.stderr, ) self._log_file_handle = sys.stdout elif (log_path is not None) and (log_file is not None): raise ValueError( "cannot specify both log_path and log_file, use streams in `SimpleLogger`" ) else: # now exactly one of the two is None if log_file is not None: self._log_file_handle = log_file else: assert log_path is not None self._log_file_handle = open(log_path, "w", encoding="utf-8") def log(self, msg: JSONitem, *, console_print: bool = False, **kwargs: Any) -> None: """log a message to the log file, and optionally to the console""" if console_print: print(msg) msg_dict: dict[str, Any] if not isinstance(msg, typing.Mapping): msg_dict = {"_msg": msg} else: msg_dict = dict(typing.cast(typing.Mapping[str, Any], msg)) if self._timestamp: msg_dict["_timestamp"] = time.time() if len(kwargs) > 0: msg_dict["_kwargs"] = kwargs self._log_file_handle.write(json.dumps(json_serialize(msg_dict)) + "\n") ``````{ end_of_file="muutils/logger/simplelogger.py" } ``````{ path="muutils/logger/timing.py" } from __future__ import annotations import time from types import TracebackType from typing import Literal class TimerContext: """context manager for timing code""" def __init__(self) -> None: self.start_time: float self.end_time: float self.elapsed_time: float def __enter__(self) -> "TimerContext": self.start_time = time.time() return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> Literal[False]: self.end_time = time.time() self.elapsed_time = self.end_time - self.start_time return False def filter_time_str(time: str) -> str: """assuming format `h:mm:ss`, clips off the hours if its 0""" if (len(time) == 7) and (time[0] == "0"): return time[3:] else: return time class ProgressEstimator: """estimates progress and can give a progress bar""" def __init__( self, n_total: int, pbar_fill: str = "█", pbar_empty: str = " ", pbar_bounds: tuple[str, str] = ("|", "|"), ): self.n_total: int = n_total self.starttime: float = time.time() self.pbar_fill: str = pbar_fill self.pbar_empty: str = pbar_empty self.pbar_bounds: tuple[str, str] = pbar_bounds self.total_str_len: int = len(str(n_total)) def get_timing_raw(self, i: int) -> dict[str, float]: """returns dict(elapsed, per_iter, remaining, percent)""" elapsed: float = time.time() - self.starttime per_iter: float = elapsed / i return dict( elapsed=elapsed, per_iter=per_iter, remaining=(self.n_total - i) * per_iter, percent=i / self.n_total, ) def get_pbar( self, i: int, width: int = 30, ) -> str: """returns a progress bar""" percent_filled: float = i / self.n_total # round to nearest integer n_filled: int = int(round(percent_filled * width)) return "".join( [ self.pbar_bounds[0], self.pbar_fill * n_filled, self.pbar_empty * (width - n_filled), self.pbar_bounds[1], ] ) def get_progress_default(self, i: int) -> str: """returns a progress string""" timing_raw: dict[str, float] = self.get_timing_raw(i) percent_str: str = str(int(timing_raw["percent"] * 100)).ljust(2) # TODO: get_progress_default # iters_str: str = 
f"{str(i).ljust(self.total_str_len)}/{self.n_total}" # timing_str: str return f"{percent_str}% {self.get_pbar(i)}" ``````{ end_of_file="muutils/logger/timing.py" } ``````{ path="muutils/math/__init__.py" } __all__ = [ "bins", "matrix_powers", ] ``````{ end_of_file="muutils/math/__init__.py" } ``````{ path="muutils/math/bins.py" } from __future__ import annotations from dataclasses import dataclass from functools import cached_property from typing import Literal import numpy as np from jaxtyping import Float @dataclass(frozen=True) class Bins: n_bins: int = 32 start: float = 0 stop: float = 1.0 scale: Literal["lin", "log"] = "log" _log_min: float = 1e-3 _zero_in_small_start_log: bool = True @cached_property def edges(self) -> Float[np.ndarray, "n_bins+1"]: if self.scale == "lin": return np.linspace(self.start, self.stop, self.n_bins + 1) elif self.scale == "log": if self.start < 0: raise ValueError( f"start must be positive for log scale, got {self.start}" ) if self.start == 0: return np.concatenate( [ # pyright: ignore[reportUnknownArgumentType] np.array([0]), np.logspace( np.log10(self._log_min), # pyright: ignore[reportAny] np.log10(self.stop), # pyright: ignore[reportAny] self.n_bins, ), ] ) elif self.start < self._log_min and self._zero_in_small_start_log: return np.concatenate( [ # pyright: ignore[reportUnknownArgumentType] np.array([0]), np.logspace( np.log10(self.start), # pyright: ignore[reportAny] np.log10(self.stop), # pyright: ignore[reportAny] self.n_bins, ), ] ) else: return np.logspace( # pyright: ignore[reportUnknownVariableType] np.log10(self.start), # pyright: ignore[reportAny] np.log10(self.stop), # pyright: ignore[reportAny] self.n_bins + 1, ) else: raise ValueError(f"Invalid scale {self.scale}, expected lin or log") @cached_property def centers(self) -> Float[np.ndarray, "n_bins"]: return (self.edges[:-1] + self.edges[1:]) / 2 def changed_n_bins_copy(self, n_bins: int) -> "Bins": return Bins( n_bins=n_bins, start=self.start, stop=self.stop, scale=self.scale, _log_min=self._log_min, _zero_in_small_start_log=self._zero_in_small_start_log, ) ``````{ end_of_file="muutils/math/bins.py" } ``````{ path="muutils/math/matrix_powers.py" } from __future__ import annotations from typing import Any, List, Sequence, TYPE_CHECKING import numpy as np from jaxtyping import Float, Int if TYPE_CHECKING: pass def matrix_powers( A: Float[np.ndarray, "n n"], powers: Sequence[int], ) -> Float[np.ndarray, "n_powers n n"]: """Compute multiple powers of a matrix efficiently. Uses binary exponentiation to compute powers in O(log max(powers)) matrix multiplications, avoiding redundant calculations when computing multiple powers. 
# Parameters: - `A : Float[np.ndarray, "n n"]` Square matrix to exponentiate - `powers : Sequence[int]` List of powers to compute (non-negative integers) # Returns: - `Float[np.ndarray, "n_powers n n"]` Array of the requested matrix powers stacked along the first dimension, one per unique requested power in sorted order """ dim_n: int = A.shape[0] assert A.shape[0] == A.shape[1], f"Matrix must be square, but got {A.shape = }" powers_np: Int[np.ndarray, "n_powers_unique"] = np.array( sorted(set(powers)), dtype=int ) n_powers_unique: int = len(powers_np) if n_powers_unique < 1: raise ValueError(f"No powers requested: {powers = }") output: Float[np.ndarray, "n_powers_unique n n"] = np.full( (n_powers_unique, dim_n, dim_n), fill_value=np.nan, dtype=A.dtype, ) # Find the maximum power to compute max_power: int = max(powers_np) # Precompute all powers of 2 up to the largest power needed # This forms our basis for binary decomposition powers_of_two: dict[int, Float[np.ndarray, "n n"]] = {} powers_of_two[0] = np.eye(dim_n, dtype=A.dtype) powers_of_two[1] = A.copy() # Compute powers of 2: A^2, A^4, A^8, ... p: int = 1 while p < max_power: if p <= max_power: A_power_p = powers_of_two[p] powers_of_two[p * 2] = A_power_p @ A_power_p p = p * 2 # For each requested power, compute it using the powers of 2 for p_idx, power in enumerate(powers_np): # Decompose power into sum of powers of 2 temp_result: Float[np.ndarray, "n n"] = powers_of_two[0].copy() temp_power: int = power p_temp: int = 1 while temp_power > 0: if temp_power % 2 == 1: temp_result = temp_result @ powers_of_two[p_temp] temp_power = temp_power // 2 p_temp *= 2 output[p_idx] = temp_result return output # BUG: breaks with integer matrices??? # TYPING: jaxtyping hints not working here, separate file for torch implementation? def matrix_powers_torch( A: Any, # : Float["torch.Tensor", "n n"], powers: Sequence[int], ) -> Any: # Float["torch.Tensor", "n_powers n n"]: """Compute multiple powers of a matrix efficiently. Uses binary exponentiation to compute powers in O(log max(powers)) matrix multiplications, avoiding redundant calculations when computing multiple powers. # Parameters: - `A : Float[torch.Tensor, "n n"]` Square matrix to exponentiate - `powers : Sequence[int]` List of powers to compute (non-negative integers) # Returns: - `Float[torch.Tensor, "n_powers n n"]` Tensor containing the requested matrix powers stacked along the first dimension # Raises: - `ValueError` : If no powers are requested or if A is not a square matrix """ import torch if len(A.shape) != 2 or A.shape[0] != A.shape[1]: raise ValueError(f"Matrix must be square, but got {A.shape = }") dim_n: int = A.shape[0] # Get unique powers and sort them unique_powers: List[int] = sorted(set(powers)) n_powers_unique: int = len(unique_powers) powers_tensor: Int[torch.Tensor, "n_powers_unique"] = torch.tensor( unique_powers, dtype=torch.int64, device=A.device ) if n_powers_unique < 1: raise ValueError(f"No powers requested: {powers = }") output: Float[torch.Tensor, "n_powers_unique n n"] = torch.full( (n_powers_unique, dim_n, dim_n), float("nan"), dtype=A.dtype, device=A.device, ) # Find the maximum power to compute max_power: int = int(powers_tensor.max().item()) # Precompute all powers of 2 up to the largest power needed # This forms our basis for binary decomposition powers_of_two: dict[int, Float[torch.Tensor, "n n"]] = {} powers_of_two[0] = torch.eye(dim_n, dtype=A.dtype, device=A.device) powers_of_two[1] = A.clone() # Compute powers of 2: A^2, A^4, A^8, ...
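# e.g. with max_power = 11 the doubling loop below builds A^2, A^4, A^8, A^16 (it may
# overshoot by one doubling); each requested power is then assembled from the set bits of
# its binary expansion in the loop that follows, e.g. A^11 = A^8 @ A^2 @ A^1 (illustrative values)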
p: int = 1 while p < max_power: if p <= max_power: A_power_p: Float[torch.Tensor, "n n"] = powers_of_two[p] powers_of_two[p * 2] = A_power_p @ A_power_p p = p * 2 # For each requested power, compute it using the powers of 2 for p_idx, power in enumerate(unique_powers): # Decompose power into sum of powers of 2 temp_result: Float[torch.Tensor, "n n"] = powers_of_two[0].clone() temp_power: int = power p_temp: int = 1 while temp_power > 0: if temp_power % 2 == 1: temp_result = temp_result @ powers_of_two[p_temp] temp_power = temp_power // 2 p_temp *= 2 output[p_idx] = temp_result return output ``````{ end_of_file="muutils/math/matrix_powers.py" } ``````{ path="muutils/misc/__init__.py" } """miscellaneous utilities - `stable_hash` for hashing that is stable across runs - `muutils.misc.sequence` for sequence manipulation, applying mappings, and string-like operations on lists - `muutils.misc.string` for sanitizing things for filenames, adjusting docstrings, and converting dicts to filenames - `muutils.misc.numerical` for turning numbers into nice strings and back - `muutils.misc.freezing` for freezing things - `muutils.misc.classes` for some weird class utilities """ # pyright: reportPrivateUsage=false from muutils.misc.hashing import stable_hash from muutils.misc.sequence import ( WhenMissing, empty_sequence_if_attr_false, flatten, list_split, list_join, apply_mapping, apply_mapping_chain, ) from muutils.misc.string import ( sanitize_name, sanitize_fname, sanitize_identifier, dict_to_filename, dynamic_docstring, ) from muutils.misc.numerical import ( shorten_numerical_to_str, str_to_numeric, _SHORTEN_MAP, ) from muutils.misc.freezing import ( FrozenDict, FrozenList, freeze, ) from muutils.misc.classes import ( is_abstract, get_all_subclasses, isinstance_by_type_name, IsDataclass, get_hashable_eq_attrs, dataclass_set_equals, ) __all__ = [ # submodules "classes", "freezing", "func", "hashing", "numerical", "sequence", "string", # imports "stable_hash", "WhenMissing", "empty_sequence_if_attr_false", "flatten", "list_split", "list_join", "apply_mapping", "apply_mapping_chain", "sanitize_name", "sanitize_fname", "sanitize_identifier", "dict_to_filename", "dynamic_docstring", "shorten_numerical_to_str", "str_to_numeric", "_SHORTEN_MAP", "FrozenDict", "FrozenList", "freeze", "is_abstract", "get_all_subclasses", "isinstance_by_type_name", "IsDataclass", "get_hashable_eq_attrs", "dataclass_set_equals", ] ``````{ end_of_file="muutils/misc/__init__.py" } ``````{ path="muutils/misc/b64_decode.py" } from sys import argv from pathlib import Path from base64 import b64decode if __name__ == "__main__": input_file: Path = Path(argv[1]) out: Path = Path(argv[2]) input_text: str = input_file.read_text().replace("\n", "") out.write_bytes(b64decode(input_text)) # pyright: ignore[reportUnusedCallResult] ``````{ end_of_file="muutils/misc/b64_decode.py" } ``````{ path="muutils/misc/classes.py" } from __future__ import annotations from typing import ( Iterable, Any, Protocol, ClassVar, runtime_checkable, ) from muutils.misc.sequence import flatten def is_abstract(cls: type) -> bool: """ Returns if a class is abstract. 
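A minimal sketch of the three cases handled below (hypothetical classes):

```python
import abc

class Base(abc.ABC):
    @abc.abstractmethod
    def run(self) -> None: ...

class Impl(Base):
    def run(self) -> None:
        pass

assert is_abstract(Base)      # abstract methods remain unimplemented
assert not is_abstract(Impl)  # concrete implementation
assert not is_abstract(int)   # ordinary class, no __abstractmethods__
```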
""" if not hasattr(cls, "__abstractmethods__"): return False # an ordinary class elif len(cls.__abstractmethods__) == 0: # type: ignore[invalid-argument-type] # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] return False # a concrete implementation of an abstract class else: return True # an abstract class def get_all_subclasses(class_: type, include_self=False) -> set[type]: """ Returns a set containing all child classes in the subclass graph of `class_`. I.e., includes subclasses of subclasses, etc. # Parameters - `include_self`: Whether to include `class_` itself in the returned set - `class_`: Superclass # Development Since most class hierarchies are small, the inefficiencies of the existing recursive implementation aren't problematic. It might be valuable to refactor with memoization if the need arises to use this function on a very large class hierarchy. """ subs: set[type] = set( flatten( get_all_subclasses(sub, include_self=True) for sub in class_.__subclasses__() if sub is not None ) ) if include_self: subs.add(class_) return subs def isinstance_by_type_name(o: object, type_name: str): """Behaves like stdlib `isinstance` except it accepts a string representation of the type rather than the type itself. This is a hacky function intended to circumvent the need to import a type into a module. It is susceptible to type name collisions. # Parameters `o`: Object (not the type itself) whose type to interrogate `type_name`: The string returned by `type_.__name__`. Generic types are not supported, only types that would appear in `type_.__mro__`. """ return type_name in {s.__name__ for s in type(o).__mro__} # dataclass magic # -------------------------------------------------------------------------------- @runtime_checkable class IsDataclass(Protocol): # Generic type for any dataclass instance # https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass __dataclass_fields__: ClassVar[dict[str, Any]] # pyright: ignore[reportExplicitAny] def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any]: # pyright: ignore[reportExplicitAny] """Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself. The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical. Essentially used to generate a hashable dataclass representation for equality comparison even if it's not frozen. """ # TYPING: ty gives @Todo here return ( # type: ignore[invalid-return-type] *( getattr(dc, fld.name) for fld in filter(lambda x: x.compare, dc.__dataclass_fields__.values()) ), type(dc), ) def dataclass_set_equals( coll1: Iterable[IsDataclass], coll2: Iterable[IsDataclass] ) -> bool: """Compares 2 collections of dataclass instances as if they were sets. Duplicates are ignored in the same manner as a set. Unfrozen dataclasses can't be placed in sets since they're not hashable. Collections of them may be compared using this function. 
""" return {get_hashable_eq_attrs(x) for x in coll1} == { get_hashable_eq_attrs(y) for y in coll2 } ``````{ end_of_file="muutils/misc/classes.py" } ``````{ path="muutils/misc/freezing.py" } from __future__ import annotations from typing import Any, Iterable, NoReturn, SupportsIndex, TypeVar, overload class FrozenDict(dict): # type: ignore[type-arg] def __setitem__(self, key: Any, value: Any) -> NoReturn: raise AttributeError("dict is frozen") def __delitem__(self, key: Any) -> NoReturn: raise AttributeError("dict is frozen") class FrozenList(list): # type: ignore[type-arg] def __setitem__(self, index: SupportsIndex | slice, value: Any) -> NoReturn: raise AttributeError("list is frozen") def __delitem__(self, index: SupportsIndex | slice) -> NoReturn: raise AttributeError("list is frozen") def append(self, value: Any) -> NoReturn: raise AttributeError("list is frozen") def extend(self, iterable: Iterable[Any]) -> NoReturn: raise AttributeError("list is frozen") def insert(self, index: SupportsIndex, value: Any) -> NoReturn: raise AttributeError("list is frozen") def remove(self, value: Any) -> NoReturn: raise AttributeError("list is frozen") def pop(self, index: SupportsIndex = -1) -> NoReturn: raise AttributeError("list is frozen") def clear(self) -> NoReturn: raise AttributeError("list is frozen") FreezeMe = TypeVar("FreezeMe") @overload def freeze(instance: dict) -> FrozenDict: ... @overload def freeze(instance: list) -> FrozenList: ... @overload def freeze(instance: tuple) -> tuple: ... @overload def freeze(instance: set) -> frozenset: ... @overload def freeze(instance: FreezeMe) -> FreezeMe: ... def freeze(instance: Any) -> Any: """recursively freeze an object in-place so that its attributes and elements cannot be changed messy in the sense that sometimes the object is modified in place, but you can't rely on that. always use the return value. 
the [gelidum](https://github.com/diegojromerolopez/gelidum/) package is a more complete implementation of this idea """ # mark as frozen if hasattr(instance, "_IS_FROZEN"): if instance._IS_FROZEN: return instance # try to mark as frozen try: instance._IS_FROZEN = True # type: ignore[attr-defined] except AttributeError: pass # skip basic types, weird things, or already frozen things if isinstance(instance, (bool, int, float, str, bytes)): pass elif isinstance(instance, (type(None), type(Ellipsis))): pass elif isinstance(instance, (FrozenList, FrozenDict, frozenset)): pass # handle containers elif isinstance(instance, list): for i in range(len(instance)): instance[i] = freeze(instance[i]) instance = FrozenList(instance) elif isinstance(instance, tuple): instance = tuple(freeze(item) for item in instance) elif isinstance(instance, set): instance = frozenset({freeze(item) for item in instance}) # type: ignore[assignment] elif isinstance(instance, dict): for key, value in instance.items(): instance[key] = freeze(value) instance = FrozenDict(instance) # handle custom classes else: # set everything in the __dict__ to frozen instance.__dict__ = freeze(instance.__dict__) # type: ignore[assignment] # create a new class which inherits from the original class class FrozenClass(instance.__class__): # type: ignore[name-defined] def __setattr__(self, name: str, value: Any) -> NoReturn: raise AttributeError("class is frozen") FrozenClass.__name__ = f"FrozenClass__{instance.__class__.__name__}" FrozenClass.__module__ = instance.__class__.__module__ FrozenClass.__doc__ = instance.__class__.__doc__ # set the instance's class to the new class try: instance.__class__ = FrozenClass except TypeError as e: raise TypeError( f"Cannot freeze:\n{instance = }\n{instance.__class__ = }\n{FrozenClass = }" ) from e return instance ``````{ end_of_file="muutils/misc/freezing.py" } ``````{ path="muutils/misc/func.py" } from __future__ import annotations import functools import sys from types import CodeType import warnings from typing import Any, Callable, Tuple, cast, TypeVar # TODO: we do a lot of type weirdness here that basedpyright doesn't like # pyright: reportInvalidTypeForm=false try: if sys.version_info >= (3, 11): # 3.11+ from typing import Unpack, TypeVarTuple, ParamSpec else: # 3.9+ from typing_extensions import Unpack, TypeVarTuple, ParamSpec # type: ignore[assignment] except ImportError: warnings.warn( "muutils.misc.func could not import Unpack and TypeVarTuple from typing or typing_extensions, typed_lambda may not work" ) ParamSpec = TypeVar # type: ignore Unpack = Any # type: ignore TypeVarTuple = TypeVar # type: ignore from muutils.errormode import ErrorMode warnings.warn("muutils.misc.func is experimental, use with caution") ReturnType = TypeVar("ReturnType") T_kwarg = TypeVar("T_kwarg") T_process_in = TypeVar("T_process_in") T_process_out = TypeVar("T_process_out") FuncParams = ParamSpec("FuncParams") FuncParamsPreWrap = ParamSpec("FuncParamsPreWrap") def process_kwarg( kwarg_name: str, processor: Callable[[T_process_in], T_process_out], ) -> Callable[ [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType] ]: """Decorator that applies a processor to a keyword argument. The underlying function is expected to have a keyword argument (with name `kwarg_name`) of type `T_out`, but the caller provides a value of type `T_in` that is converted via `processor`. # Parameters: - `kwarg_name : str` The name of the keyword argument to process. 
- `processor : Callable[[T_in], T_out]` A callable that converts the input value (`T_in`) into the type expected by the function (`T_out`). # Returns: - A decorator that converts a function of type `Callable[OutputParams, ReturnType]` (expecting `kwarg_name` of type `T_out`) into one of type `Callable[InputParams, ReturnType]` (accepting `kwarg_name` of type `T_in`). """ def decorator( func: Callable[FuncParamsPreWrap, ReturnType], ) -> Callable[FuncParams, ReturnType]: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> ReturnType: if kwarg_name in kwargs: # Convert the caller’s value (of type T_in) to T_out kwargs[kwarg_name] = processor(kwargs[kwarg_name]) return func(*args, **kwargs) # type: ignore[arg-type] return cast(Callable[FuncParams, ReturnType], wrapper) return decorator # TYPING: error: Argument of type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" cannot be assigned to parameter of type "() -> ReturnType@process_kwarg" # Type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" is not assignable to type "() -> ReturnType@process_kwarg" # Extra parameter "kwarg_name" # Extra parameter "validator" (reportArgumentType) @process_kwarg("action", ErrorMode.from_any) # pyright: ignore[reportArgumentType] def validate_kwarg( kwarg_name: str, validator: Callable[[T_kwarg], bool], description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT, ) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]: """Decorator that validates a specific keyword argument. # Parameters: - `kwarg_name : str` The name of the keyword argument to validate. - `validator : Callable[[Any], bool]` A callable that returns True if the keyword argument is valid. - `description : str | None` A message template if validation fails. - `action : str` Either `"raise"` (default) or `"warn"`. # Returns: - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]` A decorator that validates the keyword argument. # Modifies: - If validation fails and `action=="warn"`, emits a warning. Otherwise, raises a ValueError. # Usage: ```python @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}") def my_func(x: int) -> int: return x assert my_func(x=1) == 1 ``` # Raises: - `ValueError` if validation fails and `action == "raise"`. 
""" def decorator( func: Callable[FuncParams, ReturnType], ) -> Callable[FuncParams, ReturnType]: @functools.wraps(func) def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType: # pyright: ignore[reportUnknownParameterType] if kwarg_name in kwargs: value: Any = kwargs[kwarg_name] if not validator(value): # ty: ignore[invalid-argument-type] msg: str = ( description.format(kwarg_name=kwarg_name, value=value) if description else f"Validation failed for keyword '{kwarg_name}' with value {value}" ) if action == "warn": warnings.warn(msg, UserWarning) else: raise ValueError(msg) return func(*args, **kwargs) return cast(Callable[FuncParams, ReturnType], wrapper) return decorator def replace_kwarg( kwarg_name: str, check: Callable[[T_kwarg], bool], replacement_value: T_kwarg, replace_if_missing: bool = False, ) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]: """Decorator that replaces a specific keyword argument value by identity comparison. # Parameters: - `kwarg_name : str` The name of the keyword argument to replace. - `check : Callable[[T_kwarg], bool]` A callable that returns True if the keyword argument should be replaced. - `replacement_value : T_kwarg` The value to replace with. - `replace_if_missing : bool` If True, replaces the keyword argument even if it's missing. # Returns: - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]` A decorator that replaces the keyword argument value. # Modifies: - Updates `kwargs[kwarg_name]` if its value is `default_value`. # Usage: ```python @replace_kwarg("x", None, "default_string") def my_func(*, x: str | None = None) -> str: return x assert my_func(x=None) == "default_string" ``` """ def decorator( func: Callable[FuncParams, ReturnType], ) -> Callable[FuncParams, ReturnType]: @functools.wraps(func) def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType: # pyright: ignore[reportUnknownParameterType] if kwarg_name in kwargs: # TODO: no way to type hint this, I think if check(kwargs[kwarg_name]): # type: ignore[arg-type] kwargs[kwarg_name] = replacement_value # ty: ignore[invalid-assignment] elif replace_if_missing and kwarg_name not in kwargs: kwargs[kwarg_name] = replacement_value # ty: ignore[invalid-assignment] return func(*args, **kwargs) return cast(Callable[FuncParams, ReturnType], wrapper) return decorator def is_none(value: Any) -> bool: return value is None def always_true(value: Any) -> bool: return True def always_false(value: Any) -> bool: return False def format_docstring( **fmt_kwargs: Any, ) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]: """Decorator that formats a function's docstring with the provided keyword arguments.""" def decorator( func: Callable[FuncParams, ReturnType], ) -> Callable[FuncParams, ReturnType]: if func.__doc__ is not None: func.__doc__ = func.__doc__.format(**fmt_kwargs) return func return decorator # TODO: no way to make the type system understand this afaik LambdaArgs = TypeVarTuple("LambdaArgs") LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...]) def typed_lambda( # pyright: ignore[reportUnknownParameterType] fn: Callable[[Unpack[LambdaArgs]], ReturnType], in_types: LambdaArgsTypes, # pyright: ignore[reportInvalidTypeVarUse] out_type: type[ReturnType], ) -> Callable[[Unpack[LambdaArgs]], ReturnType]: """Wraps a lambda function with type hints. # Parameters: - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]` The lambda function to wrap. 
- `in_types : tuple[type, ...]` Tuple of input types. - `out_type : type[ReturnType]` The output type. # Returns: - `Callable[..., ReturnType]` A new function with annotations matching the given signature. # Usage: ```python add = typed_lambda(lambda x, y: x + y, (int, int), int) assert add(1, 2) == 3 ``` # Raises: - `ValueError` if the number of input types doesn't match the lambda's parameters. """ # it will just error here if fn.__code__ doesn't exist code: CodeType = fn.__code__ # type: ignore[unresolved-attribute] n_params: int = code.co_argcount if len(in_types) != n_params: raise ValueError( f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})" ) param_names: tuple[str, ...] = code.co_varnames[:n_params] annotations: dict[str, type] = { # type: ignore[var-annotated] name: typ for name, typ in zip(param_names, in_types) # type: ignore[arg-type] } annotations["return"] = out_type @functools.wraps(fn) def wrapped(*args: Unpack[LambdaArgs]) -> ReturnType: # pyright: ignore[reportUnknownParameterType] return fn(*args) wrapped.__annotations__ = annotations return wrapped ``````{ end_of_file="muutils/misc/func.py" } ``````{ path="muutils/misc/hashing.py" } from __future__ import annotations import base64 import hashlib import json from typing import Any def stable_hash(s: str | bytes) -> int: """Returns a stable hash of the given string. not cryptographically secure, but stable between runs""" # init hash object and update with string s_bytes: bytes if isinstance(s, str): s_bytes = s.encode("utf-8") else: s_bytes = s hash_obj: hashlib._Hash = hashlib.md5(s_bytes) # pyright: ignore[reportPrivateUsage] # get digest and convert to int return int.from_bytes(hash_obj.digest(), "big") def stable_json_dumps(d: Any) -> str: # pyright: ignore[reportAny] return json.dumps( d, sort_keys=True, indent=None, ) def base64_hash(s: str | bytes) -> str: """Returns a base64 representation of the hash of the given string. not cryptographically secure""" s_bytes: bytes if isinstance(s, str): s_bytes = bytes(s, "UTF-8") else: s_bytes = s hash_bytes: bytes = hashlib.md5(s_bytes).digest() hash_b64: str = base64.b64encode(hash_bytes, altchars=b"-_").decode() return hash_b64 ``````{ end_of_file="muutils/misc/hashing.py" } ``````{ path="muutils/misc/numerical.py" } from __future__ import annotations _SHORTEN_MAP: dict[int | float, str] = { 1e3: "K", 1e6: "M", 1e9: "B", 1e12: "t", 1e15: "q", 1e18: "Q", } _SHORTEN_TUPLES: list[tuple[int | float, str]] = sorted( ((val, suffix) for val, suffix in _SHORTEN_MAP.items()), key=lambda x: -x[0], ) _REVERSE_SHORTEN_MAP: dict[str, int | float] = {v: k for k, v in _SHORTEN_MAP.items()} def shorten_numerical_to_str( num: int | float, small_as_decimal: bool = True, precision: int = 1, ) -> str: """shorten a large numerical value to a string 1234 -> 1K precision guaranteed to 1 in 10, but can be higher. reverse of `str_to_numeric` """ # small values are returned as is num_abs: float = abs(num) if num_abs < 1e3: return str(num) # iterate over suffixes from largest to smallest for i, (val, suffix) in enumerate(_SHORTEN_TUPLES): if num_abs > val or i == len(_SHORTEN_TUPLES) - 1: if (num_abs < val * 10) and small_as_decimal: return f"{num / val:.{precision}f}{suffix}" elif num_abs < val * 1e3: return f"{int(round(num / val))}{suffix}" return f"{num:.{precision}f}" def str_to_numeric( quantity: str, mapping: None | bool | dict[str, int | float] = True, ) -> int | float: """Convert a string representing a quantity to a numeric value. 
The string can represent an integer, python float, fraction, or shortened via `shorten_numerical_to_str`. # Examples: ``` >>> str_to_numeric("5") 5 >>> str_to_numeric("0.1") 0.1 >>> str_to_numeric("1/5") 0.2 >>> str_to_numeric("-1K") -1000.0 >>> str_to_numeric("1.5M") 1500000.0 >>> str_to_numeric("1.2e2") 120.0 ``` """ # check is string if not isinstance(quantity, str): # pyright: ignore[reportUnnecessaryIsInstance] raise TypeError( # pyright: ignore[reportUnreachable] f"quantity must be a string, got '{type(quantity) = }' '{quantity = }'" ) # basic int conversion try: quantity_int: int = int(quantity) return quantity_int except ValueError: pass # basic float conversion try: quantity_float: float = float(quantity) return quantity_float except ValueError: pass # mapping _mapping: dict[str, int | float] if mapping is True or mapping is None: _mapping = _REVERSE_SHORTEN_MAP else: _mapping = mapping # type: ignore[assignment] # pyright: ignore[reportAssignmentType] quantity_original: str = quantity quantity = quantity.strip() result: int | float multiplier: int | float = 1 # detect if it has a suffix suffixes_detected: list[bool] = [suffix in quantity for suffix in _mapping] n_suffixes_detected: int = sum(suffixes_detected) if n_suffixes_detected == 0: # no suffix pass elif n_suffixes_detected == 1: # find multiplier for suffix, mult in _mapping.items(): if quantity.endswith(suffix): # remove suffix, store multiplier, and break quantity = quantity[: -len(suffix)].strip() multiplier = mult break else: raise ValueError(f"Invalid suffix in {quantity_original}") else: # multiple suffixes raise ValueError(f"Multiple suffixes detected in {quantity_original}") # fractions if "/" in quantity: try: assert quantity.count("/") == 1, "too many '/'" # split and strip num, den = quantity.split("/") num = num.strip() den = den.strip() num_sign: int = 1 # negative numbers if num.startswith("-"): num_sign = -1 num = num[1:] # assert that both are digits assert num.isdigit() and den.isdigit(), ( "numerator and denominator must be digits" ) # return the fraction result = num_sign * ( int(num) / int(den) ) # this allows for fractions with suffixes, which is weird, but whatever except AssertionError as e: raise ValueError(f"Invalid fraction {quantity_original}: {e}") from e # decimals else: try: result = int(quantity) except ValueError: try: result = float(quantity) except ValueError as e: raise ValueError( f"Invalid quantity {quantity_original} ({quantity})" ) from e return result * multiplier ``````{ end_of_file="muutils/misc/numerical.py" } ``````{ path="muutils/misc/sequence.py" } from __future__ import annotations from typing import ( Iterable, Any, Generator, Callable, Union, ) import typing from typing import ( Literal, Mapping, ) WhenMissing = Literal["except", "skip", "include"] def empty_sequence_if_attr_false( itr: Iterable[Any], attr_owner: Any, attr_name: str, ) -> Iterable[Any]: """Returns `itr` if `attr_owner` has the attribute `attr_name` and it boolean casts to `True`. Returns an empty sequence otherwise. Particularly useful for optionally inserting delimiters into a sequence depending on an `TokenizerElement` attribute. # Parameters: - `itr: Iterable[Any]` The iterable to return if the attribute is `True`. - `attr_owner: Any` The object to check for the attribute. - `attr_name: str` The name of the attribute to check. # Returns: - `itr: Iterable` if `attr_owner` has the attribute `attr_name` and it boolean casts to `True`, otherwise an empty sequence. 
- `()` an empty sequence if the attribute is `False` or not present. """ return itr if bool(getattr(attr_owner, attr_name, False)) else () def flatten(it: Iterable[Any], levels_to_flatten: int | None = None) -> Generator: """ Flattens an arbitrarily nested iterable. Flattens all iterable data types except for `str` and `bytes`. # Returns Generator over the flattened sequence. # Parameters - `it`: Any arbitrarily nested iterable. - `levels_to_flatten`: Number of levels to flatten by, starting at the outermost layer. If `None`, performs full flattening. """ for x in it: # TODO: swap type check with more general check for __iter__() or __next__() or whatever if ( hasattr(x, "__iter__") and not isinstance(x, (str, bytes)) and (levels_to_flatten is None or levels_to_flatten > 0) ): yield from flatten( x, None if levels_to_flatten is None else levels_to_flatten - 1 ) else: yield x # string-like operations on lists # -------------------------------------------------------------------------------- def list_split(lst: list, val: Any) -> list[list]: """split a list into sublists by `val`. similar to "a_b_c".split("_") ```python >>> list_split([1,2,3,0,4,5,0,6], 0) [[1, 2, 3], [4, 5], [6]] >>> list_split([0,1,2,3], 0) [[], [1, 2, 3]] >>> list_split([1,2,3], 0) [[1, 2, 3]] >>> list_split([], 0) [[]] ``` """ if len(lst) == 0: return [[]] output: list[list] = [ [], ] for x in lst: if x == val: output.append([]) else: output[-1].append(x) return output def list_join(lst: list, factory: Callable) -> list: """add a *new* instance of `factory()` between each element of `lst` ```python >>> list_join([1,2,3], lambda : 0) [1,0,2,0,3] >>> list_join([1,2,3], lambda: [time.sleep(0.1), time.time()][1]) [1, 1600000000.0, 2, 1600000000.1, 3] ``` """ if len(lst) == 0: return [] output: list = [ lst[0], ] for x in lst[1:]: output.append(factory()) output.append(x) return output # applying mappings # -------------------------------------------------------------------------------- _AM_K = typing.TypeVar("_AM_K") _AM_V = typing.TypeVar("_AM_V") def apply_mapping( mapping: Mapping[_AM_K, _AM_V], iter: Iterable[_AM_K], when_missing: WhenMissing = "skip", ) -> list[Union[_AM_K, _AM_V]]: """Given an iterable and a mapping, apply the mapping to the iterable with certain options Gotcha: if `when_missing` is invalid, this is totally fine until a missing key is actually encountered. 
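For example (illustrative mapping; the `"except"` mode would instead raise a `KeyError` for the unmapped `3`):

```python
>>> apply_mapping({1: "a", 2: "b"}, [1, 2, 3])
['a', 'b']
>>> apply_mapping({1: "a", 2: "b"}, [1, 2, 3], when_missing="include")
['a', 'b', 3]
```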
Note: you can use this with `muutils.kappa.Kappa` if you want to pass a function instead of a dict # Parameters: - `mapping : Mapping[_AM_K, _AM_V]` must have `__contains__` and `__getitem__`, both of which take `_AM_K` and the latter returns `_AM_V` - `iter : Iterable[_AM_K]` the iterable to apply the mapping to - `when_missing : WhenMissing` what to do when a key is missing from the mapping -- this is what distinguishes this function from `map` you can choose from `"skip"`, `"include"` (without converting), and `"except"` (defaults to `"skip"`) # Returns: return type is one of: - `list[_AM_V]` if `when_missing` is `"skip"` or `"except"` - `list[Union[_AM_K, _AM_V]]` if `when_missing` is `"include"` # Raises: - `KeyError` : if the item is missing from the mapping and `when_missing` is `"except"` - `ValueError` : if `when_missing` is invalid """ output: list[Union[_AM_K, _AM_V]] = list() item: _AM_K for item in iter: if item in mapping: output.append(mapping[item]) continue if when_missing == "skip": continue elif when_missing == "include": output.append(item) elif when_missing == "except": raise KeyError(f"item {item} is missing from mapping {mapping}") else: raise ValueError( f"invalid value for {when_missing = }\n{item = }\n{mapping = }" ) return output def apply_mapping_chain( mapping: Mapping[_AM_K, Iterable[_AM_V]], iter: Iterable[_AM_K], when_missing: WhenMissing = "skip", ) -> list[Union[_AM_K, _AM_V]]: """Given an iterable and a mapping, chain the mappings together Gotcha: if `when_missing` is invalid, this is totally fine until a missing key is actually encountered. Note: you can use this with `muutils.kappa.Kappa` if you want to pass a function instead of a dict # Parameters: - `mapping : Mapping[_AM_K, Iterable[_AM_V]]` must have `__contains__` and `__getitem__`, both of which take `_AM_K` and the latter returns `Iterable[_AM_V]` - `iter : Iterable[_AM_K]` the iterable to apply the mapping to - `when_missing : WhenMissing` what to do when a key is missing from the mapping -- this is what distinguishes this function from `map` you can choose from `"skip"`, `"include"` (without converting), and `"except"` (defaults to `"skip"`) # Returns: return type is one of: - `list[_AM_V]` if `when_missing` is `"skip"` or `"except"` - `list[Union[_AM_K, _AM_V]]` if `when_missing` is `"include"` # Raises: - `KeyError` : if the item is missing from the mapping and `when_missing` is `"except"` - `ValueError` : if `when_missing` is invalid """ output: list[Union[_AM_K, _AM_V]] = list() item: _AM_K for item in iter: if item in mapping: output.extend(mapping[item]) continue if when_missing == "skip": continue elif when_missing == "include": output.append(item) elif when_missing == "except": raise KeyError(f"item {item} is missing from mapping {mapping}") else: raise ValueError( f"invalid value for {when_missing = }\n{item = }\n{mapping = }" ) return output ``````{ end_of_file="muutils/misc/sequence.py" } ``````{ path="muutils/misc/string.py" } from __future__ import annotations from typing import Any, Callable, TypeVar from muutils.misc.hashing import stable_hash def sanitize_name( name: str | None, additional_allowed_chars: str = "", replace_invalid: str = "", when_none: str | None = "_None_", leading_digit_prefix: str = "", ) -> str: """sanitize a string, leaving only alphanumerics and `additional_allowed_chars` # Parameters: - `name : str | None` input string - `additional_allowed_chars : str` additional characters to allow, none by default (defaults to `""`) - `replace_invalid : str` character 
to replace invalid characters with (defaults to `""`) - `when_none : str | None` string to return if `name` is `None`. if `None`, raises an exception (defaults to `"_None_"`) - `leading_digit_prefix : str` character to prefix the string with if it starts with a digit (defaults to `""`) # Returns: - `str` sanitized string """ if name is None: if when_none is None: raise ValueError("name is None") else: return when_none sanitized: str = "" for char in name: if char.isalnum(): sanitized += char elif char in additional_allowed_chars: sanitized += char else: sanitized += replace_invalid if sanitized[0].isdigit(): sanitized = leading_digit_prefix + sanitized return sanitized def sanitize_fname( fname: str | None, replace_invalid: str = "", when_none: str | None = "_None_", leading_digit_prefix: str = "", ) -> str: """sanitize a filename to posix standards - leave only alphanumerics, `_` (underscore), '-' (dash) and `.` (period) """ return sanitize_name( name=fname, additional_allowed_chars="._-", replace_invalid=replace_invalid, when_none=when_none, leading_digit_prefix=leading_digit_prefix, ) def sanitize_identifier( fname: str | None, replace_invalid: str = "", when_none: str | None = "_None_", ) -> str: """sanitize an identifier (variable or function name) - leave only alphanumerics and `_` (underscore) - prefix with `_` if it starts with a digit """ return sanitize_name( name=fname, additional_allowed_chars="_", replace_invalid=replace_invalid, when_none=when_none, leading_digit_prefix="_", ) def dict_to_filename( data: dict[str, Any], format_str: str = "{key}_{val}", separator: str = ".", max_length: int = 255, ): # Convert the dictionary items to a list of strings using the format string formatted_items: list[str] = [ format_str.format(key=k, val=v) for k, v in data.items() # pyright: ignore[reportAny] ] # Join the formatted items using the separator joined_str: str = separator.join(formatted_items) # Remove special characters and spaces sanitized_str: str = sanitize_fname(joined_str) # Check if the length is within limits if len(sanitized_str) <= max_length: return sanitized_str # If the string is too long, generate a hash return f"h_{stable_hash(sanitized_str)}" T_Callable = TypeVar("T_Callable", bound=Callable[..., Any]) def dynamic_docstring(**doc_params: str) -> Callable[[T_Callable], T_Callable]: def decorator(func: T_Callable) -> T_Callable: if func.__doc__: func.__doc__ = getattr(func, "__doc__", "").format(**doc_params) return func return decorator ``````{ end_of_file="muutils/misc/string.py" } ``````{ path="muutils/misc/typing_breakdown.py" } """Parse type checker outputs and generate detailed breakdown of errors by type and file. Usage: python -m muutils.misc.typing_breakdown [OPTIONS] Examples: python -m muutils.misc.typing_breakdown python -m muutils.misc.typing_breakdown --error-dir .meta/.type-errors python -m muutils.misc.typing_breakdown --top-n 15 --output .meta/typing-summary.txt """ from __future__ import annotations import argparse import os import re from collections import defaultdict from dataclasses import dataclass, field from pathlib import Path from typing import Callable, Dict, List, Literal, Tuple def strip_cwd(path: str) -> str: """Strip the current working directory from a file path to make it relative. 
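For example, with the working directory at `/repo`, `/repo/muutils/dbg.py` becomes `muutils/dbg.py`, while a path outside the working directory is returned unchanged (illustrative paths).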
Args: path: File path (absolute or relative) Returns: Relative path with CWD stripped, or original path if not under CWD """ cwd: str = os.getcwd() # Normalize both paths to handle different separators and resolve symlinks abs_path: str = os.path.abspath(path) abs_cwd: str = os.path.abspath(cwd) # Ensure CWD ends with separator for proper prefix matching if not abs_cwd.endswith(os.sep): abs_cwd += os.sep # Strip CWD prefix if present if abs_path.startswith(abs_cwd): return abs_path[len(abs_cwd) :] return path @dataclass class TypeCheckResult: "results from parsing a type checker output" type_checker: Literal["mypy", "basedpyright", "ty"] by_type: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) by_file: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) # Separate tracking for warnings (used by basedpyright) warnings_by_type: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) warnings_by_file: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) @property def total_errors(self) -> int: "total number of errors across all types, validates they match between type and file dicts" total_by_type: int = sum(self.by_type.values()) total_by_file: int = sum(self.by_file.values()) if total_by_type != total_by_file: err_msg: str = f"Error count mismatch for {self.type_checker}: by_type={total_by_type}, by_file={total_by_file}" raise ValueError(err_msg) return total_by_type def filter_by(self, top_n: int | None) -> TypeCheckResult: "return a copy with errors sorted by count and filtered to top_n items (or all if None)" # Sort by count (descending) sorted_by_type: List[Tuple[str, int]] = sorted( self.by_type.items(), key=lambda x: x[1], reverse=True, ) sorted_by_file: List[Tuple[str, int]] = sorted( self.by_file.items(), key=lambda x: x[1], reverse=True, ) sorted_warnings_by_type: List[Tuple[str, int]] = sorted( self.warnings_by_type.items(), key=lambda x: x[1], reverse=True, ) sorted_warnings_by_file: List[Tuple[str, int]] = sorted( self.warnings_by_file.items(), key=lambda x: x[1], reverse=True, ) # Apply top_n limit if specified if top_n is not None: sorted_by_type = sorted_by_type[:top_n] sorted_by_file = sorted_by_file[:top_n] sorted_warnings_by_type = sorted_warnings_by_type[:top_n] sorted_warnings_by_file = sorted_warnings_by_file[:top_n] # Create new instance with filtered data (dicts maintain insertion order in Python 3.7+) result: TypeCheckResult = TypeCheckResult(type_checker=self.type_checker) result.by_type = dict(sorted_by_type) result.by_file = dict(sorted_by_file) result.warnings_by_type = dict(sorted_warnings_by_type) result.warnings_by_file = dict(sorted_warnings_by_file) return result @property def total_warnings(self) -> int: "total number of warnings across all types" total_by_type: int = sum(self.warnings_by_type.values()) total_by_file: int = sum(self.warnings_by_file.values()) if total_by_type != total_by_file: err_msg: str = f"Warning count mismatch for {self.type_checker}: by_type={total_by_type}, by_file={total_by_file}" raise ValueError(err_msg) return total_by_type def to_toml(self) -> str: "format as TOML-like output" lines: List[str] = [] # Main section with total lines.append(f"[type_errors.{self.type_checker}]") try: lines.append(f"total_errors = {self.total_errors}") except ValueError: lines.append(f"total_errors_by_type = {sum(self.by_type.values())}") lines.append(f"total_errors_by_file = {sum(self.by_file.values())}") lines.append("") # by_type section lines.append(f"[type_errors.{self.type_checker}.by_type]") 
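# e.g. the by_type section ends up looking like (illustrative checker and counts):
#   [type_errors.mypy.by_type]
#   "arg-type" = 12
#   "assignment" = 3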
error_type: str count: int for error_type, count in self.by_type.items(): # Always quote keys lines.append(f'"{error_type}" = {count}') lines.append("") # by_file section lines.append(f"[type_errors.{self.type_checker}.by_file]") file_path: str for file_path, count in self.by_file.items(): # Always quote file paths lines.append(f'"{file_path}" = {count}') # Add warnings sections if there are any warnings if self.warnings_by_type or self.warnings_by_file: lines.append("") lines.append(f"[type_warnings.{self.type_checker}]") try: lines.append(f"total_warnings = {self.total_warnings}") except ValueError: lines.append( f"total_warnings_by_type = {sum(self.warnings_by_type.values())}" ) lines.append( f"total_warnings_by_file = {sum(self.warnings_by_file.values())}" ) lines.append("") # warnings by_type section lines.append(f"[type_warnings.{self.type_checker}.by_type]") warning_type: str for warning_type, count in self.warnings_by_type.items(): lines.append(f'"{warning_type}" = {count}') lines.append("") # warnings by_file section lines.append(f"[type_warnings.{self.type_checker}.by_file]") for file_path, count in self.warnings_by_file.items(): lines.append(f'"{file_path}" = {count}') return "\n".join(lines) def parse_mypy(content: str) -> TypeCheckResult: "parse mypy output: file.py:line: error: message [error-code]" result: TypeCheckResult = TypeCheckResult(type_checker="mypy") pattern: re.Pattern[str] = re.compile( r"^(.+?):\d+: error: .+ \[(.+?)\]", re.MULTILINE ) match: re.Match[str] for match in pattern.finditer(content): file_path: str = match.group(1) error_code: str = match.group(2) result.by_type[error_code] += 1 result.by_file[file_path] += 1 return result def parse_basedpyright(content: str) -> TypeCheckResult: "parse basedpyright output: path on line, then indented errors with (code)" result: TypeCheckResult = TypeCheckResult(type_checker="basedpyright") # Pattern for file paths (lines that start with /) # Pattern for errors: indented line with - error/warning: message (code) # Some diagnostics span multiple lines with (reportCode) on a continuation line current_file: str = "" pending_diagnostic_type: str | None = None # "error" or "warning" waiting for code line: str for line in content.splitlines(): # Check if this is a file path line (starts with / and no leading space) if line and not line.startswith(" ") and line.startswith("/"): current_file = strip_cwd(line.strip()) pending_diagnostic_type = None elif line.strip() and current_file: # Try to match single-line format: " path:line:col - warning: message (reportCode)" match: re.Match[str] | None = re.search( r"\s+.+:\d+:\d+ - (error|warning): .+ \((\w+)\)", line ) if match: diagnostic_type: str = match.group(1) error_code: str = match.group(2) if diagnostic_type == "warning": result.warnings_by_type[error_code] += 1 result.warnings_by_file[current_file] += 1 else: result.by_type[error_code] += 1 result.by_file[current_file] += 1 pending_diagnostic_type = None else: # Check if this is a diagnostic line without code (multi-line format start) diag_match: re.Match[str] | None = re.search( r"\s+.+:\d+:\d+ - (error|warning): ", line ) if diag_match: pending_diagnostic_type = diag_match.group(1) # Check if this is a continuation line with the code elif pending_diagnostic_type: code_match: re.Match[str] | None = re.search(r"\((\w+)\)\s*$", line) if code_match: error_code = code_match.group(1) if pending_diagnostic_type == "warning": result.warnings_by_type[error_code] += 1 result.warnings_by_file[current_file] += 1 else: 
result.by_type[error_code] += 1 result.by_file[current_file] += 1 pending_diagnostic_type = None return result def parse_ty(content: str) -> TypeCheckResult: "parse ty output: error[error-code]: message then --> file:line:col" result: TypeCheckResult = TypeCheckResult(type_checker="ty") # Pattern for error type: error[code]: or warning[code]: error_pattern: re.Pattern[str] = re.compile( r"^(error|warning)\[(.+?)\]:", re.MULTILINE ) # Pattern for location: --> file:line:col location_pattern: re.Pattern[str] = re.compile( r"^\s+-->\s+(.+?):\d+:\d+", re.MULTILINE ) # Find all errors and their locations errors: List[re.Match[str]] = list(error_pattern.finditer(content)) locations: List[re.Match[str]] = list(location_pattern.finditer(content)) # Match errors with locations (they should be in order) error_match: re.Match[str] for error_match in errors: error_code: str = error_match.group(2) result.by_type[error_code] += 1 # Find the next location after this error error_pos: int = error_match.end() loc_match: re.Match[str] for loc_match in locations: if loc_match.start() > error_pos: file_path: str = loc_match.group(1) result.by_file[file_path] += 1 break return result def extract_summary_line(file_path: Path) -> str: "extract the last non-empty line from a file (typically the summary line)" content: str = file_path.read_text(encoding="utf-8") lines: List[str] = [line.strip() for line in content.splitlines() if line.strip()] return lines[-1] def main(error_dir: str, output_file: str, top_n: int | None = 10) -> None: "parse all type checker outputs and generate breakdown" error_path: Path = Path(error_dir) output_path: Path = Path(output_file) output_lines: List[str] = [] # Add header comment with top_n info if top_n is None: output_lines.append("# Showing all errors") else: output_lines.append(f"# Showing top {top_n} errors per category") output_lines.append("") # First, extract summary lines from each type checker checkers_files: List[Tuple[str, str]] = [ ("mypy", "mypy.txt"), ("basedpyright", "basedpyright.txt"), ("ty", "ty.txt"), ] name: str filename: str for name, filename in checkers_files: file_path: Path = error_path / filename summary: str = extract_summary_line(file_path) output_lines.append(f"# {name}: {summary}") output_lines.append("") # Parse each type checker checkers: List[Tuple[str, str, Callable[[str], TypeCheckResult]]] = [ ("mypy", "mypy.txt", parse_mypy), ("basedpyright", "basedpyright.txt", parse_basedpyright), ("ty", "ty.txt", parse_ty), ] parser_fn: Callable[[str], TypeCheckResult] for name, filename, parser_fn in checkers: file_path_: Path = error_path / filename content: str = file_path_.read_text(encoding="utf-8") result: TypeCheckResult = parser_fn(content) # Filter and sort the result filtered_result: TypeCheckResult = result.filter_by(top_n) # Convert to TOML breakdown: str = filtered_result.to_toml() output_lines.append(breakdown) output_lines.append("") # Add blank line between checkers # Write to output file final_output: str = "\n".join(output_lines) output_path.parent.mkdir(parents=True, exist_ok=True) _ = output_path.write_text(final_output, encoding="utf-8") # Also print to stdout print(final_output) if __name__ == "__main__": parser: argparse.ArgumentParser = argparse.ArgumentParser( description="Parse type checker outputs and generate detailed breakdown of errors by type and file", formatter_class=argparse.RawDescriptionHelpFormatter, ) _ = parser.add_argument( "--error-dir", type=str, default=".meta/.type-errors", help="Directory containing type checker 
output files (default: .meta/.type-errors)", ) _ = parser.add_argument( "--output", "-o", type=str, default=".meta/typing-summary.txt", help="Output file to write summary to (default: .meta/typing-summary.txt)", ) _ = parser.add_argument( "--top-n", "-n", type=str, default="10", help='Number of top items to show in each category (default: 10). Use "all" or negative number for all items.', ) args: argparse.Namespace = parser.parse_args() # Parse top_n value assert isinstance(args.top_n, str) # pyright: ignore[reportAny] top_n_value: int | None if args.top_n.lower() == "all": top_n_value = None else: top_n_int: int = int(args.top_n) top_n_value = None if top_n_int < 0 else top_n_int main(error_dir=args.error_dir, output_file=args.output, top_n=top_n_value) # pyright: ignore[reportAny] ``````{ end_of_file="muutils/misc/typing_breakdown.py" } ``````{ path="muutils/ml/cuda_mem_info.py" } from __future__ import annotations import torch # pyright: reportUnreachable=false, reportUnnecessaryIsInstance=false def _to_cuda_device(device: int | str | torch.device) -> torch.device: """Return a normalized CUDA device object.""" dev: torch.device if isinstance(device, torch.device): dev = device elif isinstance(device, int): dev = torch.device(f"cuda:{device}") elif isinstance(device, str): # Accept forms like "cuda", "cuda:0", or bare index "0" dev = torch.device(device) else: raise TypeError(f"Unsupported device type: {type(device).__name__}") if dev.type != "cuda": raise ValueError(f"Device {dev} is not a CUDA device") return dev def cuda_mem_info(dev: torch.device) -> tuple[int, int]: """Return (free, total) bytes for a CUDA device.""" current_idx: int = torch.cuda.current_device() if dev.index != current_idx: torch.cuda.set_device(dev) free: int total: int free, total = torch.cuda.mem_get_info() torch.cuda.set_device(current_idx) else: free, total = torch.cuda.mem_get_info() return free, total def cuda_memory_used(device: int | str | torch.device = 0) -> int: """Return bytes currently allocated on a CUDA device.""" dev: torch.device = _to_cuda_device(device) free, total = cuda_mem_info(dev) used: int = total - free return used def cuda_memory_fraction(device: int | str | torch.device = 0) -> float: """Return fraction of total memory in use on a CUDA device.""" dev: torch.device = _to_cuda_device(device) free, total = cuda_mem_info(dev) used: int = total - free fraction: float = used / total if total else 0.0 return fraction ``````{ end_of_file="muutils/ml/cuda_mem_info.py" } ``````{ path="muutils/nbutils/__init__.py" } """utilities for working with notebooks - configuring figures mdoes and torch devices: `configure_notebook` - converting them to scripts: `convert_ipynb_to_script` - running them as tests: `run_notebook_tests` - and working with diagrams/LaTeX: `mermaid`, `print_tex` """ from muutils.nbutils.mermaid import mm __all__ = [ # sub-modules "configure_notebook", "convert_ipynb_to_script", "mermaid", "print_tex", "run_notebook_tests", # functions "mm", ] ``````{ end_of_file="muutils/nbutils/__init__.py" } ``````{ path="muutils/nbutils/configure_notebook.py" } """shared utilities for setting up a notebook""" from __future__ import annotations import os import typing import warnings import matplotlib.pyplot as plt # type: ignore[import] class PlotlyNotInstalledWarning(UserWarning): pass # handle plotly importing PLOTLY_IMPORTED: bool try: import plotly.io as pio # type: ignore[import] except ImportError: warnings.warn( "Plotly not installed. 
Plotly plots will not be available.", PlotlyNotInstalledWarning, ) PLOTLY_IMPORTED = False else: PLOTLY_IMPORTED = True # figure out if we're in a jupyter notebook try: from IPython import get_ipython # type: ignore[import-not-found] IN_JUPYTER = get_ipython() is not None except ImportError: IN_JUPYTER = False # muutils imports from muutils.mlutils import get_device, set_reproducibility # noqa: E402 # handling figures PlottingMode = typing.Literal["ignore", "inline", "widget", "save"] PLOT_MODE: PlottingMode = "inline" CONVERSION_PLOTMODE_OVERRIDE: PlottingMode | None = None FIG_COUNTER: int = 0 FIG_OUTPUT_FMT: str | None = None FIG_NUMBERED_FNAME: str = "figure-{num}" FIG_CONFIG: dict | None = None FIG_BASEPATH: str | None = None CLOSE_AFTER_PLOTSHOW: bool = False MATPLOTLIB_FORMATS = ["pdf", "png", "jpg", "jpeg", "svg", "eps", "ps", "tif", "tiff"] TIKZPLOTLIB_FORMATS = ["tex", "tikz"] class UnknownFigureFormatWarning(UserWarning): pass def universal_savefig(fname: str, fmt: str | None = None) -> None: # try to infer format from fname if fmt is None: fmt = fname.split(".")[-1] if not (fmt in MATPLOTLIB_FORMATS or fmt in TIKZPLOTLIB_FORMATS): warnings.warn( f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'", UnknownFigureFormatWarning, ) fmt = FIG_OUTPUT_FMT # not sure why linting is throwing an error here if not fname.endswith(fmt): # type: ignore[arg-type] fname += f".{fmt}" if fmt in MATPLOTLIB_FORMATS: plt.savefig(fname, format=fmt, bbox_inches="tight") elif fmt in TIKZPLOTLIB_FORMATS: import tikzplotlib # type: ignore[import] tikzplotlib.save(fname) else: warnings.warn(f"Unknown format '{fmt}', going with matplotlib default") plt.savefig(fname, bbox_inches="tight") def setup_plots( plot_mode: PlottingMode = "inline", fig_output_fmt: str | None = "pdf", fig_numbered_fname: str = "figure-{num}", fig_config: dict | None = None, fig_basepath: str | None = None, close_after_plotshow: bool = False, ) -> None: """Set up plot saving/rendering options""" global \ PLOT_MODE, \ CONVERSION_PLOTMODE_OVERRIDE, \ FIG_COUNTER, \ FIG_OUTPUT_FMT, \ FIG_NUMBERED_FNAME, \ FIG_CONFIG, \ FIG_BASEPATH, \ CLOSE_AFTER_PLOTSHOW # set plot mode, handling override if CONVERSION_PLOTMODE_OVERRIDE is not None: # override if set PLOT_MODE = CONVERSION_PLOTMODE_OVERRIDE else: # otherwise use the given plot mode PLOT_MODE = plot_mode FIG_COUNTER = 0 CLOSE_AFTER_PLOTSHOW = close_after_plotshow if PLOT_MODE == "inline": if IN_JUPYTER: ipython = get_ipython() # pyright: ignore[reportPossiblyUnboundVariable] ipython.magic("matplotlib inline") # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] else: raise RuntimeError( f"Cannot use inline plotting outside of Jupyter\n{PLOT_MODE = }\t{CONVERSION_PLOTMODE_OVERRIDE = }" ) return elif PLOT_MODE == "widget": if IN_JUPYTER: ipython = get_ipython() # pyright: ignore[reportPossiblyUnboundVariable] ipython.magic("matplotlib widget") # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] else: # matplotlib outside of jupyter will bring up a new window by default pass return elif PLOT_MODE == "ignore": # disable plotting plt.show = lambda: None # type: ignore[misc] return # everything except saving handled up to this point assert PLOT_MODE == "save", f"Invalid plot mode: {PLOT_MODE}" FIG_OUTPUT_FMT = fig_output_fmt FIG_NUMBERED_FNAME = fig_numbered_fname FIG_CONFIG = fig_config # set default figure format in rcParams savefig.format plt.rcParams["savefig.format"] = FIG_OUTPUT_FMT if FIG_OUTPUT_FMT in TIKZPLOTLIB_FORMATS: try: import 
tikzplotlib # type: ignore[import] # noqa: F401 except ImportError: warnings.warn( f"Tikzplotlib not installed. Cannot save figures in Tikz format '{FIG_OUTPUT_FMT}', things might break." ) else: if FIG_OUTPUT_FMT not in MATPLOTLIB_FORMATS: warnings.warn( f'Unknown figure format, things might break: {plt.rcParams["savefig.format"] = }' ) # if base path not given, make one if fig_basepath is None: if fig_config is None: # if no config, use the current time from datetime import datetime fig_basepath = f"figures/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" else: # if config given, convert to string from muutils.misc import dict_to_filename fig_basepath = f"figures/{dict_to_filename(fig_config)}" FIG_BASEPATH = fig_basepath os.makedirs(fig_basepath, exist_ok=True) # if config given, serialize and save that config if fig_config is not None: import json from muutils.json_serialize import json_serialize with open(f"{fig_basepath}/config.json", "w") as f: json.dump( json_serialize(fig_config), f, indent="\t", ) print(f"Figures will be saved to: '{fig_basepath}'") def configure_notebook( *args: typing.Any, seed: int = 42, device: typing.Any = None, # this can be a string, torch.device, or None dark_mode: bool = True, plot_mode: PlottingMode = "inline", fig_output_fmt: str | None = "pdf", fig_numbered_fname: str = "figure-{num}", fig_config: dict | None = None, fig_basepath: str | None = None, close_after_plotshow: bool = False, ) -> "torch.device|None": # type: ignore[name-defined] # noqa: F821 """Shared Jupyter notebook setup steps - Set random seeds and library reproducibility settings - Set device based on availability - Set module reloading before code execution - Set plot formatting - Set plot saving/rendering options # Parameters: - `seed : int` random seed across libraries including torch, numpy, and random (defaults to `42`) (defaults to `42`) - `device : typing.Any` pytorch device to use (defaults to `None`) - `dark_mode : bool` figures in dark mode (defaults to `True`) - `plot_mode : PlottingMode` how to display plots, one of `PlottingMode` or `["ignore", "inline", "widget", "save"]` (defaults to `"inline"`) - `fig_output_fmt : str | None` format for saving figures (defaults to `"pdf"`) - `fig_numbered_fname : str` format for saving figures with numbers (if they aren't named) (defaults to `"figure-{num}"`) - `fig_config : dict | None` metadata to save with the figures (defaults to `None`) - `fig_basepath : str | None` base path for saving figures (defaults to `None`) - `close_after_plotshow : bool` close figures after showing them (defaults to `False`) # Returns: - `torch.device|None` the device set, if torch is installed """ # set some globals related to plotting setup_plots( plot_mode=plot_mode, fig_output_fmt=fig_output_fmt, fig_numbered_fname=fig_numbered_fname, fig_config=fig_config, fig_basepath=fig_basepath, close_after_plotshow=close_after_plotshow, ) global PLOT_MODE, FIG_OUTPUT_FMT, FIG_BASEPATH print(f"set up plots with {PLOT_MODE = }, {FIG_OUTPUT_FMT = }, {FIG_BASEPATH = }") # Set seeds and other reproducibility-related library options set_reproducibility(seed) # Reload modules before executing user code if IN_JUPYTER: ipython = get_ipython() # pyright: ignore[reportPossiblyUnboundVariable] if "IPython.extensions.autoreload" not in ipython.extension_manager.loaded: # pyright: ignore[reportOptionalMemberAccess] ipython.magic("load_ext autoreload") # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] ipython.magic("autoreload 2") # pyright: 
ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] # Specify plotly renderer for vscode if PLOTLY_IMPORTED: pio.renderers.default = "notebook_connected" # pyright: ignore[reportPossiblyUnboundVariable] if dark_mode: pio.templates.default = "plotly_dark" # pyright: ignore[reportPossiblyUnboundVariable] plt.style.use("dark_background") try: # Set device device = get_device(device) return device except ImportError: warnings.warn("Torch not installed. Cannot get/set device.") return None def plotshow( fname: str | None = None, plot_mode: PlottingMode | None = None, fmt: str | None = None, ): """Show the active plot, depending on global configs""" global FIG_COUNTER, CLOSE_AFTER_PLOTSHOW, PLOT_MODE FIG_COUNTER += 1 if plot_mode is None: plot_mode = PLOT_MODE if plot_mode == "save": # get numbered figure name if not given if fname is None: fname = FIG_NUMBERED_FNAME.format(num=FIG_COUNTER) # save figure assert FIG_BASEPATH is not None universal_savefig(os.path.join(FIG_BASEPATH, fname), fmt=fmt) elif plot_mode == "ignore": # do nothing pass elif plot_mode == "inline": # show figure plt.show() elif plot_mode == "widget": # show figure plt.show() else: warnings.warn(f"Invalid plot mode: {plot_mode}") if CLOSE_AFTER_PLOTSHOW: plt.close() ``````{ end_of_file="muutils/nbutils/configure_notebook.py" } ``````{ path="muutils/nbutils/convert_ipynb_to_script.py" } """fast conversion of Jupyter Notebooks to scripts, with some basic and hacky filtering and formatting.""" from __future__ import annotations import argparse import json import os from pathlib import Path import sys import typing import warnings from muutils.spinner import SpinnerContext DISABLE_PLOTS: dict[str, list[str]] = { "matplotlib": [ """ # ------------------------------------------------------------ # Disable matplotlib plots, done during processing by `convert_ipynb_to_script.py` import matplotlib.pyplot as plt plt.show = lambda: None # ------------------------------------------------------------ """ ], "circuitsvis": [ """ # ------------------------------------------------------------ # Disable circuitsvis plots, done during processing by `convert_ipynb_to_script.py` from circuitsvis.utils.convert_props import PythonProperty, convert_props from circuitsvis.utils.render import RenderedHTML, render, render_cdn, render_local def new_render( react_element_name: str, **kwargs: PythonProperty ) -> RenderedHTML: "return a visualization as raw HTML" local_src = render_local(react_element_name, **kwargs) cdn_src = render_cdn(react_element_name, **kwargs) # return as string instead of RenderedHTML for CI return str(RenderedHTML(local_src, cdn_src)) render = new_render # ------------------------------------------------------------ """ ], "muutils": [ """import muutils.nbutils.configure_notebook as nb_conf nb_conf.CONVERSION_PLOTMODE_OVERRIDE = "ignore" """ ], } DISABLE_PLOTS_WARNING: list[str] = [ """ # ------------------------------------------------------------ # WARNING: this script is auto-generated by `convert_ipynb_to_script.py` # showing plots has been disabled, so this is presumably in a temp dict for CI or something # so don't modify this code, it will be overwritten! 
# ------------------------------------------------------------ """.lstrip() ] def disable_plots_in_script(script_lines: list[str]) -> list[str]: """Disable plots in a script by adding cursed things after the import statements""" result_str_TEMP: str = "\n\n".join(script_lines) script_lines_new: list[str] = script_lines if "muutils" in result_str_TEMP: script_lines_new = DISABLE_PLOTS["muutils"] + script_lines_new if "matplotlib" in result_str_TEMP: assert "import matplotlib.pyplot as plt" in result_str_TEMP, ( "matplotlib.pyplot must be imported as plt" ) # find the last import statement involving matplotlib, and the first line that uses plt mpl_last_import_index: int = -1 mpl_first_usage_index: int = -1 for i, line in enumerate(script_lines_new): if "matplotlib" in line and (("import" in line) or ("from" in line)): mpl_last_import_index = i if "configure_notebook" in line: mpl_last_import_index = i if "plt." in line: mpl_first_usage_index = i assert mpl_last_import_index != -1, ( f"matplotlib imports not found! see line {mpl_last_import_index}" ) if mpl_first_usage_index != -1: assert mpl_first_usage_index > mpl_last_import_index, ( f"matplotlib plots created before import! see lines {mpl_first_usage_index}, {mpl_last_import_index}" ) else: warnings.warn( "could not find where matplotlib is used, plot disabling might not work!" ) # insert the cursed things script_lines_new = ( script_lines_new[: mpl_last_import_index + 1] + DISABLE_PLOTS["matplotlib"] + script_lines_new[mpl_last_import_index + 1 :] ) result_str_TEMP = "\n\n".join(script_lines_new) if "circuitsvis" in result_str_TEMP: # find the last import statement involving circuitsvis, and the first line that uses it cirv_last_import_index: int = -1 cirv_first_usage_index: int = -1 for i, line in enumerate(script_lines_new): if "circuitsvis" in line: if (("import" in line) or ("from" in line)) and "circuitsvis" in line: cirv_last_import_index = i else: cirv_first_usage_index = i if "configure_notebook" in line: mpl_last_import_index = i if "render" in line: cirv_first_usage_index = i assert cirv_last_import_index != -1, ( f"circuitsvis imports not found! see line {cirv_last_import_index}" ) if cirv_first_usage_index != -1: assert cirv_first_usage_index > cirv_last_import_index, ( f"circuitsvis plots created before import! see lines {cirv_first_usage_index}, {cirv_last_import_index}" ) else: warnings.warn( "could not find where circuitsvis is used, plot disabling might not work!" ) # insert the cursed things script_lines_new = ( script_lines_new[: cirv_last_import_index + 1] + DISABLE_PLOTS["circuitsvis"] + script_lines_new[cirv_last_import_index + 1 :] ) result_str_TEMP = "\n\n".join(script_lines_new) return script_lines_new def convert_ipynb( notebook: dict, strip_md_cells: bool = False, header_comment: str = r"#%%", disable_plots: bool = False, filter_out_lines: str | typing.Sequence[str] = ( "%", "!", ), # ignore notebook magic commands and shell commands ) -> str: """Convert Jupyter Notebook to a script, doing some basic filtering and formatting. # Arguments - `notebook: dict`: Jupyter Notebook loaded as json. - `strip_md_cells: bool = False`: Remove markdown cells from the output script. - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script. - `disable_plots: bool = False`: Disable plots in the output script. - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks). 
if a string is passed, it will be split by char and each char will be treated as a separate filter. # Returns - `str`: Converted script. """ if isinstance(filter_out_lines, str): filter_out_lines = tuple(filter_out_lines) filter_out_lines_set: set = set(filter_out_lines) result: list[str] = [] all_cells: list[dict] = notebook["cells"] for cell in all_cells: cell_type: str = cell["cell_type"] if not strip_md_cells and cell_type == "markdown": result.append(f'{header_comment}\n"""\n{"".join(cell["source"])}\n"""') elif cell_type == "code": source: list[str] = cell["source"] if filter_out_lines: source = [ ( f"#{line}" if any( line.startswith(filter_prefix) for filter_prefix in filter_out_lines_set ) else line ) for line in source ] result.append(f"{header_comment}\n{''.join(source)}") if disable_plots: result = disable_plots_in_script(result) result = DISABLE_PLOTS_WARNING + result return "\n\n".join(result) def process_file( in_file: str, out_file: str | None = None, strip_md_cells: bool = False, header_comment: str = r"#%%", disable_plots: bool = False, filter_out_lines: str | typing.Sequence[str] = ("%", "!"), ): print(f"\tProcessing {in_file}...", file=sys.stderr) assert os.path.exists(in_file), f"File {in_file} does not exist." assert os.path.isfile(in_file), f"Path {in_file} is not a file." assert in_file.endswith(".ipynb"), f"File {in_file} is not a Jupyter Notebook." with open(in_file, "r") as file: notebook: dict = json.load(file) try: converted_script: str = convert_ipynb( notebook=notebook, strip_md_cells=strip_md_cells, header_comment=header_comment, disable_plots=disable_plots, filter_out_lines=filter_out_lines, ) except AssertionError as e: print(f"Error converting {in_file}: {e}", file=sys.stderr) raise e if out_file: with open(out_file, "w") as file: file.write(converted_script) else: print(converted_script) def process_dir( input_dir: typing.Union[str, Path], output_dir: typing.Union[str, Path], strip_md_cells: bool = False, header_comment: str = r"#%%", disable_plots: bool = False, filter_out_lines: str | typing.Sequence[str] = ("%", "!"), ): """Convert all Jupyter Notebooks in a directory to scripts. # Arguments - `input_dir: str`: Input directory. - `output_dir: str`: Output directory. - `strip_md_cells: bool = False`: Remove markdown cells from the output script. - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script. - `disable_plots: bool = False`: Disable plots in the output script. - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks). if a string is passed, it will be split by char and each char will be treated as a separate filter. """ assert os.path.exists(input_dir), f"Directory {input_dir} does not exist." assert os.path.isdir(input_dir), f"Path {input_dir} is not a directory." if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) filenames: list[str] = [ fname for fname in os.listdir(input_dir) if fname.endswith(".ipynb") ] assert filenames, f"Directory {input_dir} does not contain any Jupyter Notebooks." 
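    # What follows mirrors `process_file`, but batched: each .ipynb in the
    # directory is loaded as JSON, converted with `convert_ipynb`, and written to
    # `output_dir` with a `.py` extension, while a spinner reports progress on stderr.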
n_files: int = len(filenames) print(f"Converting {n_files} notebooks:", file=sys.stderr) with SpinnerContext( spinner_chars="braille", update_interval=0.01, format_string_when_updated=True, output_stream=sys.stderr, ) as spinner: for idx, fname in enumerate(filenames): spinner.update_value(f"\tConverting {idx + 1}/{n_files}: {fname}") in_file: str = os.path.join(input_dir, fname) out_file: str = os.path.join(output_dir, fname.replace(".ipynb", ".py")) with open(in_file, "r", encoding="utf-8") as file_in: notebook: dict = json.load(file_in) try: converted_script: str = convert_ipynb( notebook=notebook, strip_md_cells=strip_md_cells, header_comment=header_comment, disable_plots=disable_plots, filter_out_lines=filter_out_lines, ) except AssertionError as e: spinner.stop() raise Exception(f"Error converting {in_file}") from e with open(out_file, "w", encoding="utf-8") as file_out: file_out.write(converted_script) if __name__ == "__main__": parser = argparse.ArgumentParser( description="Convert Jupyter Notebook to a script with cell separators." ) parser.add_argument( "in_path", type=str, help="Input Jupyter Notebook file (.ipynb) or directory of files.", ) parser.add_argument( "--out-file", type=str, help="Output script file. If not specified, the result will be printed to stdout.", ) parser.add_argument( "--output-dir", type=str, help="Output directory for converted script files." ) parser.add_argument( "--strip-md-cells", action="store_true", help="Remove markdown cells from the output script.", ) parser.add_argument( "--header-comment", type=str, default=r"#%%", help="Comment string to separate cells in the output script.", ) parser.add_argument( "--disable-plots", action="store_true", help="Disable plots in the output script. Useful for testing in CI.", ) parser.add_argument( "--filter-out-lines", type=str, default="%", help="Comment out lines starting with these characters.", ) args = parser.parse_args() if args.output_dir: assert not args.out_file, "Cannot specify both --out_file and --output_dir." 
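        # illustrative invocation of directory mode (the paths here are placeholders):
        #   python -m muutils.nbutils.convert_ipynb_to_script notebooks/ \
        #       --output-dir notebooks_converted/ --disable-plots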
process_dir( input_dir=args.in_path, output_dir=args.output_dir, strip_md_cells=args.strip_md_cells, header_comment=args.header_comment, disable_plots=args.disable_plots, filter_out_lines=args.filter_out_lines, ) else: process_file( in_file=args.in_path, out_file=args.out_file, strip_md_cells=args.strip_md_cells, header_comment=args.header_comment, disable_plots=args.disable_plots, filter_out_lines=args.filter_out_lines, ) print("muutils.nbutils.convert_ipynb_to_script.py loaded.") ``````{ end_of_file="muutils/nbutils/convert_ipynb_to_script.py" } ``````{ path="muutils/nbutils/mermaid.py" } """display mermaid.js diagrams in jupyter notebooks by the `mermaid.ink/img` service""" import base64 try: from IPython.display import Image, display except ImportError: import warnings warnings.warn( "IPython.display could not be imported, mermaid will not work", ImportWarning ) def mm(graph: str) -> None: """for plotting mermaid.js diagrams""" graphbytes = graph.encode("ascii") base64_bytes = base64.b64encode(graphbytes) base64_string = base64_bytes.decode("ascii") display(Image(url="https://mermaid.ink/img/" + base64_string)) # pyright: ignore[reportPossiblyUnboundVariable] ``````{ end_of_file="muutils/nbutils/mermaid.py" } ``````{ path="muutils/nbutils/print_tex.py" } """quickly print a sympy expression in latex""" from __future__ import annotations import sympy as sp # type: ignore # pyright: ignore[reportMissingTypeStubs] from IPython.display import Math, display # type: ignore # pyright: ignore[reportUnknownVariableType] def print_tex( expr: sp.Expr, # type: ignore name: str | None = None, plain: bool = False, rendered: bool = True, ): """function for easily rendering a sympy expression in latex""" out: str = sp.latex(expr) # pyright: ignore[reportUnknownVariableType] if name is not None: out = f"{name} = {out}" if plain: print(out) # pyright: ignore[reportUnknownArgumentType] if rendered: display(Math(out)) # pyright: ignore[reportUnusedCallResult] ``````{ end_of_file="muutils/nbutils/print_tex.py" } ``````{ path="muutils/nbutils/run_notebook_tests.py" } """turn a folder of notebooks into scripts, run them, and make sure they work. made to be called as ```bash python -m muutils.nbutils.run_notebook_tests --notebooks-dir --converted-notebooks-temp-dir ``` """ import os import subprocess import sys from pathlib import Path from typing import Optional import warnings from muutils.console_unicode import get_console_safe_str from muutils.spinner import SpinnerContext class NotebookTestError(Exception): pass SUCCESS_STR: str = get_console_safe_str("✅", "[OK]") FAILURE_STR: str = get_console_safe_str("❌", "[!!]") def run_notebook_tests( notebooks_dir: Path, converted_notebooks_temp_dir: Path, CI_output_suffix: str = ".CI-output.txt", run_python_cmd: Optional[str] = None, run_python_cmd_fmt: str = "{python_tool} run python", python_tool: str = "poetry", exit_on_first_fail: bool = False, ): """Run converted Jupyter notebooks as Python scripts and verify they execute successfully. Takes a directory of notebooks and their corresponding converted Python scripts, executes each script, and captures the output. Failures are collected and reported, with optional early exit on first failure. 
# Parameters: - `notebooks_dir : Path` Directory containing the original .ipynb notebook files - `converted_notebooks_temp_dir : Path` Directory containing the corresponding converted .py files - `CI_output_suffix : str` Suffix to append to output files capturing execution results (defaults to `".CI-output.txt"`) - `run_python_cmd : str | None` Custom command to run Python scripts. Overrides python_tool and run_python_cmd_fmt if provided (defaults to `None`) - `run_python_cmd_fmt : str` Format string for constructing the Python run command (defaults to `"{python_tool} run python"`) - `python_tool : str` Tool used to run Python (e.g. poetry, uv) (defaults to `"poetry"`) - `exit_on_first_fail : bool` Whether to raise exception immediately on first notebook failure (defaults to `False`) # Returns: - `None` # Modifies: - Working directory: Temporarily changes to notebooks_dir during execution - Filesystem: Creates output files with CI_output_suffix for each notebook # Raises: - `NotebookTestError`: If any notebooks fail to execute, or if input directories are invalid - `TypeError`: If run_python_cmd is provided but not a string # Usage: ```python >>> run_notebook_tests( ... notebooks_dir=Path("notebooks"), ... converted_notebooks_temp_dir=Path("temp/converted"), ... python_tool="poetry" ... ) # testing notebooks in 'notebooks' # reading converted notebooks from 'temp/converted' Running 1/2: temp/converted/notebook1.py Output in temp/converted/notebook1.CI-output.txt {SUCCESS_STR} Run completed with return code 0 ``` """ run_python_cmd_: str if run_python_cmd is None: run_python_cmd_ = run_python_cmd_fmt.format(python_tool=python_tool) elif isinstance(run_python_cmd, str): run_python_cmd_ = run_python_cmd warnings.warn( "You have specified a custom run_python_cmd, this will override the `python_tool` parameter and `run_python_cmd_fmt` parameter. 
This will be removed in a future version.", DeprecationWarning, ) else: raise TypeError( f"run_python_cmd must be a string or None, got {run_python_cmd =}, {type(run_python_cmd) =}" ) original_cwd: Path = Path.cwd() # get paths notebooks_dir = Path(notebooks_dir) converted_notebooks_temp_dir = Path(converted_notebooks_temp_dir) root_relative_to_notebooks: Path = Path(os.path.relpath(".", notebooks_dir)) term_width: int try: term_width = os.get_terminal_size().columns except OSError: term_width = 80 exceptions: dict[str, str] = dict() print(f"# testing notebooks in '{notebooks_dir}'") print( f"# reading converted notebooks from '{converted_notebooks_temp_dir.as_posix()}'" ) try: # check things exist if not notebooks_dir.exists(): raise NotebookTestError(f"Notebooks dir '{notebooks_dir}' does not exist") if not notebooks_dir.is_dir(): raise NotebookTestError( f"Notebooks dir '{notebooks_dir}' is not a directory" ) if not converted_notebooks_temp_dir.exists(): raise NotebookTestError( f"Converted notebooks dir '{converted_notebooks_temp_dir}' does not exist" ) if not converted_notebooks_temp_dir.is_dir(): raise NotebookTestError( f"Converted notebooks dir '{converted_notebooks_temp_dir}' is not a directory" ) notebooks: list[Path] = list(notebooks_dir.glob("*.ipynb")) if not notebooks: raise NotebookTestError(f"No notebooks found in '{notebooks_dir}'") converted_notebooks: list[Path] = list() for nb in notebooks: converted_file: Path = ( converted_notebooks_temp_dir / nb.with_suffix(".py").name ) if not converted_file.exists(): raise NotebookTestError( f"Did not find converted notebook '{converted_file}' for '{nb}'" ) converted_notebooks.append(converted_file) del converted_file # pyright: ignore[reportPossiblyUnboundVariable] # the location of this line is important os.chdir(notebooks_dir) n_notebooks: int = len(converted_notebooks) for idx, file in enumerate(converted_notebooks): # run the file print(f"Running {idx + 1}/{n_notebooks}: {file.as_posix()}") output_file: Path = file.with_suffix(CI_output_suffix) print(f" Output in {output_file.as_posix()}") with SpinnerContext( spinner_chars="braille", update_interval=0.5, format_string="\r {spinner} ({elapsed_time:.2f}s) {message}{value}", ): command: str = f"{run_python_cmd_} {root_relative_to_notebooks / file} > {root_relative_to_notebooks / output_file} 2>&1" process: subprocess.CompletedProcess = subprocess.run( command, shell=True, text=True, env={**os.environ, "PYTHONIOENCODING": "utf-8"}, ) if process.returncode == 0: print( f" {SUCCESS_STR} Run completed with return code {process.returncode}" ) else: print( f" {FAILURE_STR} Run failed with return code {process.returncode}!!! Check {output_file.as_posix()}" ) # print the output of the file to the console if it failed if process.returncode != 0: with open(root_relative_to_notebooks / output_file, "r") as f: file_output: str = f.read() err: str = f"Error in {file}:\n{'-' * term_width}\n{file_output}" exceptions[file.as_posix()] = err if exit_on_first_fail: raise NotebookTestError(err) del process if len(exceptions) > 0: exceptions_str: str = ("\n" + "=" * term_width + "\n").join( list(exceptions.values()) ) raise NotebookTestError( exceptions_str + "=" * term_width + f"\n{FAILURE_STR} {len(exceptions)}/{n_notebooks} notebooks failed:\n{list(exceptions.keys())}" ) except NotebookTestError as e: print("!" * term_width, file=sys.stderr) print(e, file=sys.stderr) print("!" 
* term_width, file=sys.stderr) raise e finally: # return to original cwd os.chdir(original_cwd) if __name__ == "__main__": import argparse parser: argparse.ArgumentParser = argparse.ArgumentParser() parser.add_argument( "--notebooks-dir", type=str, help="The directory from which to run the notebooks", ) parser.add_argument( "--converted-notebooks-temp-dir", type=str, help="The directory containing the converted notebooks to test", ) parser.add_argument( "--python-tool", type=str, default="poetry", help="The python tool to use to run the notebooks (usually uv or poetry)", ) parser.add_argument( "--run-python-cmd-fmt", type=str, default="{python_tool} run python", help="The command to run python with the python tool. if you don't want to use poetry or uv, you can just set this to 'python'", ) args: argparse.Namespace = parser.parse_args() run_notebook_tests( notebooks_dir=Path(args.notebooks_dir), converted_notebooks_temp_dir=Path(args.converted_notebooks_temp_dir), python_tool=args.python_tool, run_python_cmd_fmt=args.run_python_cmd_fmt, ) ``````{ end_of_file="muutils/nbutils/run_notebook_tests.py" } ``````{ path="muutils/web/__init__.py" } __all__ = [ "bundle_html", ] ``````{ end_of_file="muutils/web/__init__.py" } ``````{ path="muutils/web/bundle_html.py" } """ Inline / bundle external assets (CSS, JS, SVG, PNG) into an HTML document. Default mode uses **zero external dependencies** and a few well-targeted regular expressions. If you install *beautifulsoup4* you can enable the far more robust BS4 mode by passing `InlineConfig(use_bs4=True)`. """ from __future__ import annotations import base64 import re import urllib.request import warnings from dataclasses import dataclass, field from pathlib import Path from typing import Final, Literal # bs4 import deferred to avoid an unconditional dependency. # constants # --------------------------------------------------------------------- AssetExt = Literal[".css", ".js", ".svg", ".png"] DEFAULT_ALLOWED_EXTENSIONS: Final[set[AssetExt]] = {".css", ".js", ".svg", ".png"} DEFAULT_TAG_ATTR: Final[dict[str, str]] = { "link": "href", # "script": "src", # "img": "src", # "use": "xlink:href", # } MIME_BY_EXT: Final[dict[AssetExt, str]] = { ".css": "text/css", ".js": "application/javascript", ".svg": "image/svg+xml", ".png": "image/png", } # Configuration # --------------------------------------------------------------------- @dataclass class InlineConfig: """High-level configuration for the inliner. # Parameters - `allowed_extensions : set[AssetExt]` Extensions that may be inlined. - `tag_attr : dict[str, str]` Mapping *tag -> attribute* that holds the asset reference. - `max_bytes : int` Assets larger than this are ignored. - `local : bool` Allow local filesystem assets. - `remote : bool` Allow remote http/https assets. - `include_filename_comments : bool` Surround every replacement with `` and ``. - `use_bs4 : bool` Parse the document with BeautifulSoup if available. 
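    # Usage

    A minimal sketch (`html_text` and the `site/` directory are placeholders,
    not part of the library):

    ```python
    cfg = InlineConfig(remote=True, max_bytes=256 * 1024)
    html_out = inline_html_assets(html_text, base_path=Path("site/"), config=cfg)
    ```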
""" allowed_extensions: set[AssetExt] = field( default_factory=lambda: set(DEFAULT_ALLOWED_EXTENSIONS) ) tag_attr: dict[str, str] = field(default_factory=lambda: dict(DEFAULT_TAG_ATTR)) max_bytes: int = 128 * 1024 local: bool = True remote: bool = False include_filename_comments: bool = True use_bs4: bool = False # Low-level helpers # --------------------------------------------------------------------- def _is_remote(url: str) -> bool: """Return *True* if *url* starts with http:// or https://.""" return url.lower().startswith(("http://", "https://")) def _fetch_bytes(src: str, base: Path) -> bytes: """Fetch *src* (local or remote) and return its raw bytes.""" if _is_remote(src): with urllib.request.urlopen(src) as resp: return resp.read() return (base / src).read_bytes() def _decode_text(buf: bytes) -> str: """Decode *buf* as UTF-8, falling back to replacement.""" try: return buf.decode() except UnicodeDecodeError: return buf.decode("utf-8", "replace") # Regex-based implementation (no deps) # --------------------------------------------------------------------- def _apply_indent(html: str, start: int, replacement: str) -> str: """Indent *replacement* to match the line that starts at *start*.""" line_start: int = html.rfind("\n", 0, start) + 1 indent: str = html[line_start:start] return "\n".join(indent + line for line in replacement.splitlines()) def _inline_with_regex(html: str, base: Path, cfg: InlineConfig) -> str: """Inline assets using pure-regex parsing (no third-party libs).""" tag: str attr: str for tag, attr in cfg.tag_attr.items(): pattern: str if tag == "script": pattern = ( rf"]*\s{attr}\s*=\s*['\"]([^'\"]+)['\"][^>]*>\s*" ) elif tag == "link": pattern = rf"]*\s{attr}\s*=\s*['\"]([^'\"]+)['\"][^>]*>" else: # img, use, etc. pattern = rf"<{tag}\b[^>]*\s{attr}\s*=\s*['\"]([^'\"]+)['\"][^>]*>" matches: list[re.Match[str]] = list(re.finditer(pattern, html, re.IGNORECASE)) m: re.Match[str] for m in reversed(matches): raw_src: str = m.group(1) # may contain #fragment clean_src: str = re.split(r"[?#]", raw_src, maxsplit=1)[0] # file path only ext: str = Path(clean_src).suffix.lower() if ext not in cfg.allowed_extensions: continue if _is_remote(clean_src) and not cfg.remote: continue if not _is_remote(clean_src) and not cfg.local: continue try: data: bytes = _fetch_bytes(clean_src, base) except Exception as err: warnings.warn(f"skip '{raw_src}': {err}") continue if len(data) > cfg.max_bytes: continue # build replacement replacement: str if ext in {".css", ".js"}: tag_name: str = "style" if ext == ".css" else "script" replacement = f"<{tag_name}>\n{_decode_text(data)}\n" else: # .svg or .png b64: str = base64.b64encode(data).decode() # TYPING: we check earlier, ext if for sure in MIME_BY_EXT data_uri: str = f"data:{MIME_BY_EXT[ext]};base64,{b64}" # type: ignore[index] replacement = m.group(0).replace(raw_src, data_uri, 1) if cfg.include_filename_comments: replacement = f"\n{replacement}\n" replacement = _apply_indent(html, m.start(), replacement) html = html[: m.start()] + replacement + html[m.end() :] return html # BeautifulSoup-based implementation (optional) # --------------------------------------------------------------------- def _inline_with_bs4(html: str, base: Path, cfg: InlineConfig) -> str: """Inline assets using BeautifulSoup when available.""" try: from bs4 import BeautifulSoup, Comment, Tag except ModuleNotFoundError as exc: # pragma: no cover raise RuntimeError("BeautifulSoup requested but not installed") from exc soup: BeautifulSoup = BeautifulSoup(html, "html.parser") 
tag: Tag # TYPING: i think soup.find_all() returns a list of Tag objects? mypy thinks it should be PageElement (of which Tag is a subclass) for tag in list(soup.find_all(cfg.tag_attr.keys())): # type: ignore[assignment] attr: str = cfg.tag_attr[tag.name] # TYPING: error: Incompatible types in assignment (expression has type "str | AttributeValueList | None", variable has type "str | None") [assignment] src_full: str | None = tag.get(attr) # type: ignore[assignment] if not src_full: continue clean_src: str = re.split(r"[?#]", src_full, maxsplit=1)[0] ext: str = Path(clean_src).suffix.lower() if ext not in cfg.allowed_extensions: continue if _is_remote(clean_src) and not cfg.remote: continue if not _is_remote(clean_src) and not cfg.local: continue try: data: bytes = _fetch_bytes(clean_src, base) except Exception as err: warnings.warn(f"skip '{src_full}': {err}") continue if len(data) > cfg.max_bytes: continue if ext in {".css", ".js"}: new_tag: Tag = soup.new_tag("style" if ext == ".css" else "script") new_tag.string = _decode_text(data) if cfg.include_filename_comments: tag.insert_before(Comment(f" begin '{src_full}' ")) tag.insert_after(Comment(f" end '{src_full}' ")) tag.replace_with(new_tag) else: # .svg or .png b64: str = base64.b64encode(data).decode() # we are sure ext is in MIME_BY_EXT, so ignore type error tag[attr] = f"data:{MIME_BY_EXT[ext]};base64,{b64}" # type: ignore[index] if cfg.include_filename_comments: tag.insert_before(Comment(f" begin '{src_full}' ")) tag.insert_after(Comment(f" end '{src_full}' ")) return str(soup) # Public API # --------------------------------------------------------------------- def inline_html_assets( html: str, *, base_path: Path, config: InlineConfig | None = None, prettify: bool = False, # kept for API compatibility (ignored in regex mode) ) -> str: """Inline permitted external assets inside *html*. # Parameters - `html : str` Raw HTML text. - `base_path : Path` Directory used to resolve relative asset paths. - `config : InlineConfig | None` Inlining options (see `InlineConfig`). - `prettify : bool` Pretty-print output (only effective in BS4 mode). # Returns - `str` Modified HTML. """ cfg: InlineConfig = config or InlineConfig() if cfg.use_bs4: html_out: str = _inline_with_bs4(html, base_path, cfg) if prettify: # lazy import to avoid unconditional dependency from bs4 import BeautifulSoup # TYPING: .prettify() returns str if no encoding is set html_out = str(BeautifulSoup(html_out, "html.parser").prettify()) else: html_out = _inline_with_regex(html, base_path, cfg) return html_out def inline_html_file( html_path: Path, output_path: Path, base_path: Path | None = None, config: InlineConfig | None = None, prettify: bool = False, ) -> Path: """Read *html_path*, inline its assets, and write the result. # Parameters - `html_path : Path` Source HTML file. - `output_path : Path` Destination path to write the modified HTML. - `base_path : Path | None` Directory used to resolve relative asset paths (defaults to the HTML file's directory). If `None`, uses the directory of *html_path*. (default: `None` -> use `html_path.parent`) - `config : InlineConfig | None` Inlining options. If `None`, uses default configuration. (default: `None` -> use `InlineConfig()`) - `prettify : bool` Pretty-print when `use_bs4=True`. (default: `False`) # Returns - `Path` Path actually written. 
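    # Usage

    A minimal sketch (file names are placeholders):

    ```python
    inline_html_file(
        html_path=Path("report.html"),
        output_path=Path("report.bundled.html"),
    )
    ```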
""" if base_path is None: base_path = html_path.parent html_raw: str = html_path.read_text() html_new: str = inline_html_assets( html_raw, base_path=base_path, config=config, prettify=prettify, ) dest: Path = output_path or html_path dest.write_text(html_new) return dest # CLI # --------------------------------------------------------------------- if __name__ == "__main__": import argparse parser: argparse.ArgumentParser = argparse.ArgumentParser( description="Inline / bundle CSS, JS, SVG, PNG assets. " "Uses regex parsing by default; pass --bs4 to require BeautifulSoup." ) parser.add_argument("html", type=Path, help="input HTML file") parser.add_argument( "-o", "--output", type=Path, help="output file", required=True, ) parser.add_argument( "--source-dir", type=Path, default=None, help="base directory for relative asset paths (defaults to the HTML file's directory)", ) parser.add_argument("--remote", action="store_true", help="allow remote URLs") parser.add_argument("--bs4", action="store_true", help="use BeautifulSoup parser") parser.add_argument( "--prettify", action="store_true", help="pretty-print with BeautifulSoup)" ) parser.add_argument( "--max-bytes", type=int, default=128 * 1024, help="size limit per asset" ) parser.add_argument( "--ext", nargs="+", default=list(DEFAULT_ALLOWED_EXTENSIONS), help="extensions to inline", ) parser.add_argument( "--tag-attr", type=str, default=None, help='override tag->attr map. format: "tag1=attr1,tag2=attr2"', ) parser.add_argument("--no-comments", dest="comments", action="store_false") args: argparse.Namespace = parser.parse_args() tag_attr: dict[str, str] if args.tag_attr: tag_attr = { tag: attr for tag, attr in (item.split("=") for item in args.tag_attr.split(",")) } else: tag_attr = dict(DEFAULT_TAG_ATTR) cfg: InlineConfig = InlineConfig( allowed_extensions=set(args.ext), # type: ignore[arg-type] tag_attr=tag_attr, max_bytes=args.max_bytes, remote=args.remote, include_filename_comments=args.comments, use_bs4=args.bs4, ) inline_html_file( args.html, output_path=args.output, base_path=args.source_dir, config=cfg, prettify=args.prettify, ) ``````{ end_of_file="muutils/web/bundle_html.py" } ``````{ path="muutils/web/html_to_pdf.py" } from pathlib import Path import subprocess from weasyprint import HTML as WeasyHTML # type: ignore[import-untyped] def html_to_pdf(src: Path, dst: Path) -> None: "write HTML file to PDF using WeasyPrint." WeasyHTML(filename=src.as_posix()).write_pdf(dst.as_posix()) def crop(pdf_in: Path, pdf_out: Path, margin_pt: int = 2) -> None: """Run pdfcrop with a tiny safety margin.""" subprocess.run( ["pdfcrop", "--margins", str(margin_pt), pdf_in.as_posix(), pdf_out.as_posix()], check=True, ) def save_html_to_pdf( html: str, pdf_out: Path, pdfcrop: bool = True, margin_pt: int = 2, ) -> None: """Save HTML string to PDF file.""" if isinstance(pdf_out, str): pdf_out = Path(pdf_out) temp_html: Path = pdf_out.with_suffix(".html") temp_html.write_text(html, encoding="utf-8") html_to_pdf(temp_html, pdf_out) if pdfcrop: crop(pdf_out, pdf_out, margin_pt) # Clean up temporary HTML file temp_html.unlink(missing_ok=True) ``````{ end_of_file="muutils/web/html_to_pdf.py" } ``````{ path="muutils/__init__.py" } """ .. 
include:: ../README.md """ from __future__ import annotations __all__ = [ # submodules (with sub-submodules) "json_serialize", "logger", "math", "misc", "nbutils", "web", # submodules "collect_warnings", "console_unicode", "dbg", "dictmagic", "errormode", "group_equiv", "interval", "jsonlines", "kappa", "mlutils", "parallel", "spinner", "statcounter", "sysinfo", "tensor_info", "tensor_utils", "timeit_fancy", "validate_type", ] ``````{ end_of_file="muutils/__init__.py" } ``````{ path="muutils/collect_warnings.py" } from __future__ import annotations import sys import warnings from collections import Counter from contextlib import AbstractContextManager from types import TracebackType from typing import Any, Literal class CollateWarnings(AbstractContextManager): # type: ignore[type-arg] """Capture every warning issued inside a `with` block and print a collated summary when the block exits. Internally this wraps `warnings.catch_warnings(record=True)` so that all warnings raised in the block are recorded. When the context exits, identical warnings are grouped and (optionally) printed with a user-defined format. # Parameters: - `print_on_exit : bool` Whether to print the summary when the context exits (defaults to `True`) - `fmt : str` Format string used for printing each line of the summary. Available fields are: * `{count}` : number of occurrences * `{filename}` : file where the warning originated * `{lineno}` : line number * `{category}` : warning class name * `{message}` : warning message text (defaults to `"({count}x) {filename}:{lineno} {category}: {message}"`) # Returns: - `CollateWarnings` The context-manager instance. After exit, the attribute `counts` holds a mapping ```python {(filename, lineno, category, message): count} ``` # Usage: ```python >>> import warnings >>> with CollateWarnings() as cw: ... warnings.warn("deprecated", DeprecationWarning) ... warnings.warn("deprecated", DeprecationWarning) ... 
warnings.warn("other", UserWarning) (2x) /tmp/example.py:42 DeprecationWarning: deprecated (1x) /tmp/example.py:43 UserWarning: other >>> cw.counts {('/tmp/example.py', 42, 'DeprecationWarning', 'deprecated'): 2, ('/tmp/example.py', 43, 'UserWarning', 'other'): 1} ``` """ _active: bool _catcher: Any _records: list[warnings.WarningMessage] counts: Counter[ tuple[ str, # filename int, # lineno str, # category name str, # message ] ] print_on_exit: bool fmt: str def __init__( self, print_on_exit: bool = True, fmt: str = "({count}x) {filename}:{lineno} {category}: {message}", ) -> None: self.print_on_exit = print_on_exit self.fmt = fmt self._active = False self._records = [] self.counts = Counter() def __enter__(self) -> CollateWarnings: if self._active: raise RuntimeError("CollateWarnings cannot be re-entered") self._active = True self._catcher = warnings.catch_warnings(record=True) self._records = self._catcher.__enter__() warnings.simplefilter("always") # capture every warning return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> Literal[False]: if not self._active: raise RuntimeError("CollateWarnings exited twice") self._active = False # stop capturing self._catcher.__exit__(exc_type, exc_val, exc_tb) # collate self.counts = Counter( ( rec.filename, rec.lineno, rec.category.__name__, str(rec.message), ) for rec in self._records ) if self.print_on_exit: for (filename, lineno, category, message), count in self.counts.items(): print( self.fmt.format( count=count, filename=filename, lineno=lineno, category=category, message=message, ), file=sys.stderr, ) # propagate any exception from the with-block return False ``````{ end_of_file="muutils/collect_warnings.py" } ``````{ path="muutils/console_unicode.py" } import locale def get_console_safe_str( default: str, fallback: str, ) -> str: """Determine a console-safe string based on the preferred encoding. This function attempts to encode a given `default` string using the system's preferred encoding. If encoding is successful, it returns the `default` string; otherwise, it returns a `fallback` string. # Parameters: - `default : str` The primary string intended for use, to be tested against the system's preferred encoding. - `fallback : str` The alternative string to be used if `default` cannot be encoded in the system's preferred encoding. # Returns: - `str` Either `default` or `fallback` based on whether `default` can be encoded safely. # Usage: ```python >>> get_console_safe_str("café", "cafe") "café" # This result may vary based on the system's preferred encoding. 
``` """ try: _ = default.encode(locale.getpreferredencoding()) return default except UnicodeEncodeError: return fallback ``````{ end_of_file="muutils/console_unicode.py" } ``````{ path="muutils/dbg.py" } """ this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from https://github.com/tylerwince/pydbg/blob/master/pydbg.py although it has been significantly modified licensed under MIT: Copyright (c) 2019 Tyler Wince Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from __future__ import annotations import inspect import sys import typing from pathlib import Path import re # type defs _ExpType = typing.TypeVar("_ExpType") _ExpType_dict = typing.TypeVar( "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any] ) _ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any]) # TypedDict definitions for configuration dictionaries class DBGDictDefaultsType(typing.TypedDict): key_types: bool val_types: bool max_len: int indent: str max_depth: int class DBGListDefaultsType(typing.TypedDict): max_len: int summary_show_types: bool class DBGTensorArraySummaryDefaultsType(typing.TypedDict): fmt: typing.Literal["unicode", "latex", "ascii"] precision: int stats: bool shape: bool dtype: bool device: bool requires_grad: bool sparkline: bool sparkline_bins: int sparkline_logy: typing.Union[None, bool] colored: bool eq_char: str # Sentinel type for no expression passed class _NoExpPassedSentinel: """Unique sentinel type used to indicate that no expression was passed.""" pass _NoExpPassed = _NoExpPassedSentinel() # global variables _CWD: Path = Path.cwd().absolute() _COUNTER: int = 0 # configuration PATH_MODE: typing.Literal["relative", "absolute"] = "relative" DEFAULT_VAL_JOINER: str = " = " # path processing def _process_path(path: Path) -> str: path_abs: Path = path.absolute() fname: Path if PATH_MODE == "absolute": fname = path_abs elif PATH_MODE == "relative": try: # if it's inside the cwd, print the relative path fname = path.relative_to(_CWD) except ValueError: # if its not in the subpath, use the absolute path fname = path_abs else: raise ValueError("PATH_MODE must be either 'relative' or 'absolute") return fname.as_posix() # actual dbg function @typing.overload def dbg() -> _NoExpPassedSentinel: ... @typing.overload def dbg( exp: _NoExpPassedSentinel, formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None, val_joiner: str = DEFAULT_VAL_JOINER, ) -> _NoExpPassedSentinel: ... 
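# note on the overloads: `dbg(expr)` is typed to return exactly what it was given,
# so it can be dropped into an expression without changing the inferred type;
# a bare `dbg()` returns the `_NoExpPassedSentinel` and just logs the call location.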
@typing.overload def dbg( exp: _ExpType, formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None, val_joiner: str = DEFAULT_VAL_JOINER, ) -> _ExpType: ... def dbg( exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed, formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None, val_joiner: str = DEFAULT_VAL_JOINER, ) -> typing.Union[_ExpType, _NoExpPassedSentinel]: """Call dbg with any variable or expression. Calling dbg will print to stderr the current filename and lineno, as well as the passed expression and what the expression evaluates to: from muutils.dbg import dbg a = 2 b = 5 dbg(a+b) def square(x: int) -> int: return x * x dbg(square(a)) """ global _COUNTER # get the context line_exp: str = "unknown" current_file: str = "unknown" dbg_frame: typing.Optional[inspect.FrameInfo] = None for frame in inspect.stack(): if frame.code_context is None: continue line: str = frame.code_context[0] if "dbg" in line: current_file = _process_path(Path(frame.filename)) dbg_frame = frame start: int = line.find("(") + 1 end: int = line.rfind(")") if end == -1: end = len(line) line_exp = line[start:end] break fname: str = "unknown" if current_file.startswith("/tmp/ipykernel_"): stack: list[inspect.FrameInfo] = inspect.stack() filtered_functions: list[str] = [] # this loop will find, in this order: # - the dbg function call # - the functions we care about displaying # - `` # - a bunch of jupyter internals we don't care about for frame_info in stack: if _process_path(Path(frame_info.filename)) != current_file: continue if frame_info.function == "": break if frame_info.function.startswith("dbg"): continue filtered_functions.append(frame_info.function) if dbg_frame is not None: filtered_functions.append(f":{dbg_frame.lineno}") else: filtered_functions.append(current_file) filtered_functions.reverse() fname = " -> ".join(filtered_functions) elif dbg_frame is not None: fname = f"{current_file}:{dbg_frame.lineno}" # assemble the message msg: str if exp is _NoExpPassed: # if no expression is passed, just show location and counter value msg = f"[ {fname} ] " _COUNTER += 1 else: # if expression passed, format its value and show location, expr, and value exp_val: str = formatter(exp) if formatter else repr(exp) msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}" # print the message print( msg, file=sys.stderr, ) # return the expression itself return exp # formatted `dbg_*` functions with their helpers DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: DBGTensorArraySummaryDefaultsType = { "fmt": "unicode", "precision": 2, "stats": True, "shape": True, "dtype": True, "device": True, "requires_grad": True, "sparkline": True, "sparkline_bins": 7, "sparkline_logy": None, # None means auto-detect "colored": True, "eq_char": "=", } DBG_TENSOR_VAL_JOINER: str = ": " def tensor_info(tensor: typing.Any) -> str: from muutils.tensor_info import array_summary # TODO: explicitly pass args to avoid type: ignore (mypy can't match overloads with **TypedDict spread) return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS) # type: ignore[call-overload] DBG_DICT_DEFAULTS: DBGDictDefaultsType = { "key_types": True, "val_types": True, "max_len": 32, "indent": " ", "max_depth": 3, } DBG_LIST_DEFAULTS: DBGListDefaultsType = { "max_len": 16, "summary_show_types": True, } def list_info( lst: typing.List[typing.Any], ) -> str: len_l: int = len(lst) output: str if len_l > DBG_LIST_DEFAULTS["max_len"]: output = f"", "", } def dict_info( d: typing.Dict[typing.Any, typing.Any], depth: int = 0, ) -> 
str: len_d: int = len(d) indent: str = DBG_DICT_DEFAULTS["indent"] # summary line output: str = f"{indent * depth} 0: key_types: typing.Set[str] = set(type(k).__name__ for k in d.keys()) key_types_str: str = "{" + ", ".join(sorted(key_types)) + "}" output += f", key_types={key_types_str}" if DBG_DICT_DEFAULTS["val_types"] and len_d > 0: val_types: typing.Set[str] = set(type(v).__name__ for v in d.values()) val_types_str: str = "{" + ", ".join(sorted(val_types)) + "}" output += f", val_types={val_types_str}" output += ">" # keys/values if not to deep and not too many if depth < DBG_DICT_DEFAULTS["max_depth"]: if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]: for k, v in d.items(): key_str: str = repr(k) if not isinstance(k, str) else k val_str: str val_type_str: str = str(type(v)) if isinstance(v, dict): val_str = dict_info(v, depth + 1) elif val_type_str in TENSOR_STR_TYPES: val_str = tensor_info(v) elif isinstance(v, list): val_str = list_info(v) else: val_str = repr(v) output += ( f"\n{indent * (depth + 1)}{key_str}{DBG_TENSOR_VAL_JOINER}{val_str}" ) return output def info_auto( obj: typing.Any, ) -> str: """Automatically format an object for debugging.""" if isinstance(obj, dict): return dict_info(obj) elif isinstance(obj, list): return list_info(obj) elif str(type(obj)) in TENSOR_STR_TYPES: return tensor_info(obj) else: return repr(obj) def dbg_tensor( tensor: _ExpType, # numpy array or torch tensor ) -> _ExpType: """dbg function for tensors, using tensor_info formatter.""" return dbg( tensor, formatter=tensor_info, val_joiner=DBG_TENSOR_VAL_JOINER, ) def dbg_dict( d: _ExpType_dict, ) -> _ExpType_dict: """dbg function for dictionaries, using dict_info formatter.""" return dbg( d, formatter=dict_info, val_joiner=DBG_TENSOR_VAL_JOINER, ) def dbg_auto( obj: _ExpType, ) -> _ExpType: """dbg function for automatic formatting based on type.""" return dbg( obj, formatter=info_auto, val_joiner=DBG_TENSOR_VAL_JOINER, ) def _normalize_for_loose(text: str) -> str: """Normalize text for loose matching by replacing non-alphanumeric chars with spaces.""" normalized: str = re.sub(r"[^a-zA-Z0-9]+", " ", text) return " ".join(normalized.split()) def _compile_pattern( pattern: str | re.Pattern[str], *, cased: bool = False, loose: bool = False, ) -> re.Pattern[str]: """Compile pattern with appropriate flags for case sensitivity and loose matching.""" if isinstance(pattern, re.Pattern): return pattern # Start with no flags for case-insensitive default flags: int = 0 if not cased: flags |= re.IGNORECASE if loose: pattern = _normalize_for_loose(pattern) return re.compile(pattern, flags) def grep_repr( obj: typing.Any, pattern: str | re.Pattern[str], *, char_context: int | None = 20, line_context: int | None = None, before_context: int = 0, after_context: int = 0, context: int | None = None, max_count: int | None = None, cased: bool = False, loose: bool = False, line_numbers: bool = False, highlight: bool = True, color: str = "31", separator: str = "--", quiet: bool = False, ) -> typing.List[str] | None: """grep-like search on ``repr(obj)`` with improved grep-style options. By default, string patterns are case-insensitive. Pre-compiled regex patterns use their own flags. 
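    Example (illustrative; `cfg` is a placeholder object):

        grep_repr(cfg, "seed", context=1, line_numbers=True)  # print matches with context
        hits = grep_repr(cfg, r"lr=\S+", quiet=True)          # collect matches instead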
Parameters: - obj: Object to search (its repr() string is scanned) - pattern: Regular expression pattern (string or pre-compiled) - char_context: Characters of context before/after each match (default: 20) - line_context: Lines of context before/after; overrides char_context - before_context: Lines of context before match (like grep -B) - after_context: Lines of context after match (like grep -A) - context: Lines of context before AND after (like grep -C) - max_count: Stop after this many matches - cased: Force case-sensitive search for string patterns - loose: Normalize spaces/punctuation for flexible matching - line_numbers: Show line numbers in output - highlight: Wrap matches with ANSI color codes - color: ANSI color code (default: "31" for red) - separator: Separator between multiple matches - quiet: Return results instead of printing Returns: - None if quiet=False (prints to stdout) - List[str] if quiet=True (returns formatted output lines) """ # Handle context parameter shortcuts if context is not None: before_context = after_context = context # Prepare text and pattern text: str = repr(obj) if loose: text = _normalize_for_loose(text) regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose) def _color_match(segment: str) -> str: if not highlight: return segment return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment) output_lines: list[str] = [] match_count: int = 0 # Determine if we're using line-based context using_line_context = ( line_context is not None or before_context > 0 or after_context > 0 ) if using_line_context: lines: list[str] = text.splitlines() line_starts: list[int] = [] pos: int = 0 for line in lines: line_starts.append(pos) pos += len(line) + 1 # +1 for newline processed_lines: set[int] = set() for match in regex.finditer(text): if max_count is not None and match_count >= max_count: break # Find which line contains this match match_line = max( i for i, start in enumerate(line_starts) if start <= match.start() ) # Calculate context range ctx_before: int ctx_after: int if line_context is not None: ctx_before = ctx_after = line_context else: ctx_before, ctx_after = before_context, after_context start_line: int = max(0, match_line - ctx_before) end_line: int = min(len(lines), match_line + ctx_after + 1) # Avoid duplicate output for overlapping contexts line_range: set[int] = set(range(start_line, end_line)) if line_range & processed_lines: continue processed_lines.update(line_range) # Format the context block context_lines: list[str] = [] for i in range(start_line, end_line): line_text = lines[i] if line_numbers: line_prefix = f"{i + 1}:" line_text = f"{line_prefix}{line_text}" context_lines.append(_color_match(line_text)) if output_lines and separator: output_lines.append(separator) output_lines.extend(context_lines) match_count += 1 else: # Character-based context ctx: int = 0 if char_context is None else char_context for match in regex.finditer(text): if max_count is not None and match_count >= max_count: break start: int = max(0, match.start() - ctx) end: int = min(len(text), match.end() + ctx) snippet: str = text[start:end] if output_lines and separator: output_lines.append(separator) output_lines.append(_color_match(snippet)) match_count += 1 if quiet: return output_lines else: for line in output_lines: print(line) return None ``````{ end_of_file="muutils/dbg.py" } ``````{ path="muutils/dictmagic.py" } """making working with dictionaries easier - `DefaulterDict`: like a defaultdict, but default_factory is passed the key as 
an argument - various methods for working wit dotlist-nested dicts, converting to and from them - `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges - `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes - `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict """ from __future__ import annotations import typing import warnings from collections import defaultdict from typing import ( Any, Callable, Generic, Hashable, Iterable, Literal, Optional, TypeVar, Union, ) from muutils.errormode import ErrorMode _KT = TypeVar("_KT") _VT = TypeVar("_VT") class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]): """like a defaultdict, but default_factory is passed the key as an argument""" def __init__( self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any ) -> None: if args: raise TypeError( f"DefaulterDict does not support positional arguments: *args = {args}" ) super().__init__(**kwargs) self.default_factory: Callable[[_KT], _VT] = default_factory def __getitem__(self, k: _KT) -> _VT: if k in self: return dict.__getitem__(self, k) else: v: _VT = self.default_factory(k) dict.__setitem__(self, k, v) return v def _recursive_defaultdict_ctor() -> defaultdict: return defaultdict(_recursive_defaultdict_ctor) def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict: """Convert a defaultdict or DefaulterDict to a normal dict, recursively""" return { key: ( defaultdict_to_dict_recursive(value) if isinstance(value, (defaultdict, DefaulterDict)) else value ) for key, value in dd.items() } def dotlist_to_nested_dict( dot_dict: typing.Dict[str, Any], sep: str = "." ) -> typing.Dict[str, Any]: """Convert a dict with dot-separated keys to a nested dict Example: >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}) {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}} """ nested_dict: defaultdict = _recursive_defaultdict_ctor() for key, value in dot_dict.items(): if not isinstance(key, str): raise TypeError(f"key must be a string, got {type(key)}") keys: list[str] = key.split(sep) current: defaultdict = nested_dict # iterate over the keys except the last one for sub_key in keys[:-1]: current = current[sub_key] current[keys[-1]] = value return defaultdict_to_dict_recursive(nested_dict) def nested_dict_to_dotlist( nested_dict: typing.Dict[str, Any], sep: str = ".", allow_lists: bool = False, ) -> dict[str, Any]: def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]: items: dict = dict() new_key: str if isinstance(current, dict): # dict case if not current and parent_key: items[parent_key] = current else: for k, v in current.items(): new_key = f"{parent_key}{sep}{k}" if parent_key else k items.update(_recurse(v, new_key)) elif allow_lists and isinstance(current, list): # list case for i, item in enumerate(current): new_key = f"{parent_key}{sep}{i}" if parent_key else str(i) items.update(_recurse(item, new_key)) else: # anything else (write value) items[parent_key] = current return items return _recurse(nested_dict) def update_with_nested_dict( original: dict[str, Any], update: dict[str, Any], ) -> dict[str, Any]: """Update a dict with a nested dict Example: >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}}) {'a': {'b': 2}, 'c': -1} # Arguments - `original: dict[str, Any]` the dict to update (will be modified in-place) - `update: dict[str, Any]` the dict to update with # Returns - `dict` the updated dict """ for key, 
value in update.items(): if key in original: if isinstance(original[key], dict) and isinstance(value, dict): update_with_nested_dict(original[key], value) else: original[key] = value else: original[key] = value return original def kwargs_to_nested_dict( kwargs_dict: dict[str, Any], sep: str = ".", strip_prefix: Optional[str] = None, when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN, transform_key: Optional[Callable[[str], str]] = None, ) -> dict[str, Any]: """given kwargs from fire, convert them to a nested dict if strip_prefix is not None, then all keys must start with the prefix. by default, will warn if an unknown prefix is found, but can be set to raise an error or ignore it: `when_unknown_prefix: ErrorMode` Example: ```python def main(**kwargs): print(kwargs_to_nested_dict(kwargs)) fire.Fire(main) ``` running the above script will give: ```bash $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3 {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}} ``` # Arguments - `kwargs_dict: dict[str, Any]` the kwargs dict to convert - `sep: str = "."` the separator to use for nested keys - `strip_prefix: Optional[str] = None` if not None, then all keys must start with this prefix - `when_unknown_prefix: ErrorMode = ErrorMode.WARN` what to do when an unknown prefix is found - `transform_key: Callable[[str], str] | None = None` a function to apply to each key before adding it to the dict (applied after stripping the prefix) """ when_unknown_prefix_ = ErrorMode.from_any(when_unknown_prefix) filtered_kwargs: dict[str, Any] = dict() for key, value in kwargs_dict.items(): if strip_prefix is not None: if not key.startswith(strip_prefix): when_unknown_prefix_.process( f"key '{key}' does not start with '{strip_prefix}'", except_cls=ValueError, ) else: key = key[len(strip_prefix) :] if transform_key is not None: key = transform_key(key) filtered_kwargs[key] = value return dotlist_to_nested_dict(filtered_kwargs, sep=sep) def is_numeric_consecutive(lst: list[str]) -> bool: """Check if the list of keys is numeric and consecutive.""" try: numbers: list[int] = [int(x) for x in lst] return sorted(numbers) == list(range(min(numbers), max(numbers) + 1)) except ValueError: return False def condense_nested_dicts_numeric_keys( data: dict[str, Any], ) -> dict[str, Any]: """condense a nested dict, by condensing numeric keys with matching values to ranges # Examples: ```python >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2}) {'[1-3]': 1, '[4-6]': 2} >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'}) {"1": {"[1-2]": "a"}, "2": "b"} ``` """ if not isinstance(data, dict): return data # Process each sub-dictionary for key, value in list(data.items()): data[key] = condense_nested_dicts_numeric_keys(value) # Find all numeric, consecutive keys if is_numeric_consecutive(list(data.keys())): keys: list[str] = sorted(data.keys(), key=lambda x: int(x)) else: return data # output dict condensed_data: dict[str, Any] = {} # Identify ranges of identical values and condense i: int = 0 while i < len(keys): j: int = i while j + 1 < len(keys) and data[keys[j]] == data[keys[j + 1]]: j += 1 if j > i: # Found consecutive keys with identical values condensed_key: str = f"[{keys[i]}-{keys[j]}]" condensed_data[condensed_key] = data[keys[i]] i = j + 1 else: condensed_data[keys[i]] = data[keys[i]] i += 1 return condensed_data def condense_nested_dicts_matching_values( data: dict[str, Any], val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, ) -> dict[str, Any]: 
"""condense a nested dict, by condensing keys with matching values # Examples: TODO # Parameters: - `data : dict[str, Any]` data to process - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to `None`) """ if isinstance(data, dict): data = { key: condense_nested_dicts_matching_values( value, val_condense_fallback_mapping ) for key, value in data.items() } else: return data # Find all identical values and condense by stitching together keys values_grouped: defaultdict[Any, list[str]] = defaultdict(list) data_persist: dict[str, Any] = dict() for key, value in data.items(): if not isinstance(value, dict): try: values_grouped[value].append(key) except TypeError: # If the value is unhashable, use a fallback mapping to find a hashable representation if val_condense_fallback_mapping is not None: values_grouped[val_condense_fallback_mapping(value)].append(key) else: data_persist[key] = value else: data_persist[key] = value condensed_data = data_persist for value, keys in values_grouped.items(): if len(keys) > 1: merged_key = f"[{', '.join(keys)}]" # Choose an appropriate method to represent merged keys condensed_data[merged_key] = value else: condensed_data[keys[0]] = value return condensed_data def condense_nested_dicts( data: dict[str, Any], condense_numeric_keys: bool = True, condense_matching_values: bool = True, val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, ) -> dict[str, Any]: """condense a nested dict, by condensing numeric or matching keys with matching values to ranges combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()` # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes it's not reversible because types are lost to make the printing pretty # Parameters: - `data : dict[str, Any]` data to process - `condense_numeric_keys : bool` whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. 
"[1-3]") (defaults to `True`) - `condense_matching_values : bool` whether to condense keys with matching values (defaults to `True`) - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to `None`) """ condensed_data: dict = data if condense_numeric_keys: condensed_data = condense_nested_dicts_numeric_keys(condensed_data) if condense_matching_values: condensed_data = condense_nested_dicts_matching_values( condensed_data, val_condense_fallback_mapping ) return condensed_data def tuple_dims_replace( t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None ) -> tuple[Union[int, str], ...]: if dims_names_map is None: return t else: return tuple(dims_names_map.get(x, x) for x in t) TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] # noqa: F821 TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] # noqa: F821 TensorDictFormats = Literal["dict", "json", "yaml", "yml"] def _default_shapes_convert(x: tuple) -> str: return str(x).replace('"', "").replace("'", "") def condense_tensor_dict( data: TensorDict | TensorIterable, fmt: TensorDictFormats = "dict", *args: Any, shapes_convert: Callable[ [tuple[Union[int, str], ...]], Any ] = _default_shapes_convert, drop_batch_dims: int = 0, sep: str = ".", dims_names_map: Optional[dict[int, str]] = None, condense_numeric_keys: bool = True, condense_matching_values: bool = True, val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, return_format: Optional[TensorDictFormats] = None, ) -> Union[str, dict[str, str | tuple[int, ...]]]: """Convert a dictionary of tensors to a dictionary of shapes. by default, values are converted to strings of their shapes (for nice printing). If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`. # Parameters: - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]` a either a `TensorDict` dict from strings to tensors, or an `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items())` ) - `fmt : TensorDictFormats` format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. (defaults to `'dict'`) - `shapes_convert : Callable[[tuple], Any]` conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes) (defaults to `lambdax:str(x).replace('"', '').replace("'", '')`) - `drop_batch_dims : int` number of leading dimensions to drop from the shape (defaults to `0`) - `sep : str` separator to use for nested keys (defaults to `'.'`) - `dims_names_map : dict[int, str] | None` convert certain dimension values in shape. not perfect, can be buggy (defaults to `None`) - `condense_numeric_keys : bool` whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. 
"[1-3]"), passed on to `condense_nested_dicts` (defaults to `True`) - `condense_matching_values : bool` whether to condense keys with matching values, passed on to `condense_nested_dicts` (defaults to `True`) - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts` (defaults to `None`) - `return_format : TensorDictFormats | None` legacy alias for `fmt` kwarg # Returns: - `str|dict[str, str|tuple[int, ...]]` dict if `return_format='dict'`, a string for `json` or `yaml` output # Examples: ```python >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2") >>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml')) ``` ```yaml embed: W_E: (50257, 768) pos_embed: W_pos: (1024, 768) blocks: '[0-11]': attn: '[W_Q, W_K, W_V]': (12, 768, 64) W_O: (12, 64, 768) '[b_Q, b_K, b_V]': (12, 64) b_O: (768,) mlp: W_in: (768, 3072) b_in: (3072,) W_out: (3072, 768) b_out: (768,) unembed: W_U: (768, 50257) b_U: (50257,) ``` # Raises: - `ValueError` : if `return_format` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed """ # handle arg processing: # ---------------------------------------------------------------------- # make all args except data and format keyword-only assert len(args) == 0, f"unexpected positional args: {args}" # handle legacy return_format if return_format is not None: warnings.warn( "return_format is deprecated, use fmt instead", DeprecationWarning, ) fmt = return_format # identity function for shapes_convert if not provided if shapes_convert is None: shapes_convert = lambda x: x # noqa: E731 # convert to iterable data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = ( # type: ignore # noqa: F821 data.items() if hasattr(data, "items") and callable(data.items) else data # type: ignore ) # get shapes data_shapes: dict[str, Union[str, tuple[int, ...]]] = { # pyright: ignore[reportAssignmentType] k: shapes_convert( tuple_dims_replace( tuple(v.shape)[drop_batch_dims:], dims_names_map, ) ) for k, v in data_items } # nest the dict data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep) # condense the nested dict data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts( data=data_nested, condense_numeric_keys=condense_numeric_keys, condense_matching_values=condense_matching_values, val_condense_fallback_mapping=val_condense_fallback_mapping, ) # return in the specified format fmt_lower: str = fmt.lower() if fmt_lower == "dict": return data_condensed elif fmt_lower == "json": import json return json.dumps(data_condensed, indent=2) elif fmt_lower in ["yaml", "yml"]: try: import yaml # type: ignore[import-untyped] return yaml.dump(data_condensed, sort_keys=False) except ImportError as e: raise ValueError("PyYAML is required for YAML output") from e else: raise ValueError(f"Invalid return format: {fmt}") ``````{ end_of_file="muutils/dictmagic.py" } ``````{ path="muutils/errormode.py" } """provides `ErrorMode` enum for handling errors consistently pass an `error_mode: ErrorMode` to a function to specify how to handle a certain kind of exception. That function then instead of `raise`ing or `warnings.warn`ing, calls `error_mode.process` with the message and the exception. you can also specify the exception class to raise, the warning class to use, and the source of the exception/warning. 
""" from __future__ import annotations import sys import typing import types import warnings from enum import Enum class WarningFunc(typing.Protocol): def __call__( self, msg: str, category: typing.Type[Warning], source: typing.Any = None, ) -> None: ... LoggingFunc = typing.Callable[[str], None] GLOBAL_WARN_FUNC: WarningFunc = warnings.warn # type: ignore[assignment] GLOBAL_LOG_FUNC: LoggingFunc = print def custom_showwarning( message: Warning | str, category: typing.Type[Warning] | None = None, filename: str | None = None, lineno: int | None = None, file: typing.Optional[typing.TextIO] = None, line: typing.Optional[str] = None, ) -> None: if category is None: category = UserWarning # Get the frame where process() was called # Adjusted to account for the extra function call frame: types.FrameType = sys._getframe(2) # get globals and traceback traceback: types.TracebackType = types.TracebackType( None, frame, frame.f_lasti, frame.f_lineno ) _globals: dict[str, typing.Any] = frame.f_globals # init the new warning and add the traceback if isinstance(message, str): message = category(message) message = message.with_traceback(traceback) # Call the original showwarning function warnings.warn_explicit( message=message, category=category, # filename arg if it's passed, otherwise use the frame's filename filename=frame.f_code.co_filename, lineno=frame.f_lineno, module=frame.f_globals.get("__name__", "__main__"), registry=_globals.setdefault("__warningregistry__", {}), module_globals=_globals, ) # warnings._showwarning_orig( # message, # category, # frame.f_code.co_filename, # frame.f_lineno, # file, # line, # ) class ErrorMode(Enum): """Enum for handling errors consistently pass one of the instances of this enum to a function to specify how to handle a certain kind of exception. That function then instead of `raise`ing or `warnings.warn`ing, calls `error_mode.process` with the message and the exception. 
""" EXCEPT = "except" WARN = "warn" LOG = "log" IGNORE = "ignore" def process( self, msg: str, except_cls: typing.Type[Exception] = ValueError, warn_cls: typing.Type[Warning] = UserWarning, except_from: typing.Optional[Exception] = None, warn_func: WarningFunc | None = None, log_func: LoggingFunc | None = None, ): """process an exception or warning according to the error mode # Parameters: - `msg : str` message to pass to `except_cls` or `warn_func` - `except_cls : typing.Type[Exception]` exception class to raise, must be a subclass of `Exception` (defaults to `ValueError`) - `warn_cls : typing.Type[Warning]` warning class to use, must be a subclass of `Warning` (defaults to `UserWarning`) - `except_from : typing.Optional[Exception]` will `raise except_cls(msg) from except_from` if not `None` (defaults to `None`) - `warn_func : WarningFunc | None` function to use for warnings, must have the signature `warn_func(msg: str, category: typing.Type[Warning], source: typing.Any = None) -> None` (defaults to `None`) - `log_func : LoggingFunc | None` function to use for logging, must have the signature `log_func(msg: str) -> None` (defaults to `None`) # Raises: - `except_cls` : _description_ - `except_cls` : _description_ - `ValueError` : _description_ """ if self is ErrorMode.EXCEPT: # except, possibly with a chained exception frame: types.FrameType = sys._getframe(1) traceback: types.TracebackType = types.TracebackType( None, frame, frame.f_lasti, frame.f_lineno ) # Attach the new traceback to the exception and raise it without the internal call stack if except_from is not None: raise except_cls(msg).with_traceback(traceback) from except_from else: raise except_cls(msg).with_traceback(traceback) elif self is ErrorMode.WARN: # get global warn function if not passed if warn_func is None: warn_func = GLOBAL_WARN_FUNC # augment warning message with source if except_from is not None: msg = f"{msg}\n\tSource of warning: {except_from}" if warn_func == warnings.warn: custom_showwarning(msg, category=warn_cls) else: # Use the provided warn_func as-is warn_func(msg, category=warn_cls) elif self is ErrorMode.LOG: # get global log function if not passed if log_func is None: log_func = GLOBAL_LOG_FUNC # log log_func(msg) elif self is ErrorMode.IGNORE: # do nothing pass else: raise ValueError(f"Unknown error mode {self}") @classmethod def from_any( cls, mode: "str|ErrorMode", allow_aliases: bool = True, allow_prefix: bool = True, ) -> ErrorMode: """initialize an `ErrorMode` from a string or an `ErrorMode` instance""" if isinstance(mode, ErrorMode): return mode elif isinstance(mode, str): # strip mode = mode.strip() # remove prefix if allow_prefix and mode.startswith("ErrorMode."): mode = mode[len("ErrorMode.") :] # lowercase and strip again mode = mode.strip().lower() if not allow_aliases: # try without aliases try: return ErrorMode(mode) except ValueError as e: raise KeyError(f"Unknown error mode {mode = }") from e else: # look up in aliases map return ERROR_MODE_ALIASES[mode] else: raise TypeError( f"Expected {ErrorMode = } or str, got {type(mode) = } {mode = }" ) def __str__(self) -> str: return f"ErrorMode.{self.value.capitalize()}" def __repr__(self) -> str: return str(self) def serialize(self) -> str: return str(self) @classmethod def load(cls, data: str) -> ErrorMode: return cls.from_any( data, allow_aliases=False, allow_prefix=True, ) ERROR_MODE_ALIASES: dict[str, ErrorMode] = { # base "except": ErrorMode.EXCEPT, "warn": ErrorMode.WARN, "log": ErrorMode.LOG, "ignore": ErrorMode.IGNORE, # except "e": 
ErrorMode.EXCEPT, "error": ErrorMode.EXCEPT, "err": ErrorMode.EXCEPT, "raise": ErrorMode.EXCEPT, # warn "w": ErrorMode.WARN, "warning": ErrorMode.WARN, # log "l": ErrorMode.LOG, "print": ErrorMode.LOG, "output": ErrorMode.LOG, "show": ErrorMode.LOG, "display": ErrorMode.LOG, # ignore "i": ErrorMode.IGNORE, "silent": ErrorMode.IGNORE, "quiet": ErrorMode.IGNORE, "nothing": ErrorMode.IGNORE, } "map of string aliases to `ErrorMode` instances" ``````{ end_of_file="muutils/errormode.py" } ``````{ path="muutils/group_equiv.py" } "group items by assuming that `eq_func` defines an equivalence relation" from __future__ import annotations from itertools import chain from typing import Callable, Sequence, TypeVar T = TypeVar("T") def group_by_equivalence( items_in: Sequence[T], eq_func: Callable[[T, T], bool], ) -> list[list[T]]: """group items by assuming that `eq_func` implies an equivalence relation but might not be transitive so, if f(a,b) and f(b,c) then f(a,c) might be false, but we still want to put [a,b,c] in the same class note that lists are used to avoid the need for hashable items, and to allow for duplicates # Arguments - `items_in: Sequence[T]` the items to group - `eq_func: Callable[[T, T], bool]` a function that returns true if two items are equivalent. need not be transitive """ items: list[T] = list(items_in) items.reverse() output: list[list[T]] = list() while items: x: T = items.pop() # try to add to an existing class found_classes: list[int] = list() for i, c in enumerate(output): if any(eq_func(x, y) for y in c): found_classes.append(i) # if one class found, add to it if len(found_classes) == 1: output[found_classes.pop()].append(x) elif len(found_classes) > 1: # if multiple classes found, merge the classes # first sort the ones to be merged output_new: list[list[T]] = list() to_merge: list[list[T]] = list() for i, c in enumerate(output): if i in found_classes: to_merge.append(c) else: output_new.append(c) # then merge them back in, along with the element `x` merged: list[T] = list(chain.from_iterable(to_merge)) merged.append(x) output_new.append(merged) output = output_new # if no class found, make a new one else: output.append([x]) return output ``````{ end_of_file="muutils/group_equiv.py" } ``````{ path="muutils/interval.py" } "represents a mathematical `Interval` over the real numbers" from __future__ import annotations import math from typing import Optional, Sequence, Union, Any from muutils.misc import str_to_numeric _EPSILON: float = 1e-10 Number = Union[float, int] # TODO: make this also work with decimals, fractions, numpy types, etc. # except we must somehow avoid importing them? idk _EMPTY_INTERVAL_ARGS: tuple[Number, Number, bool, bool, set[Number]] = ( math.nan, math.nan, False, False, set(), ) class Interval: """ Represents a mathematical interval, open by default. The Interval class can represent both open and closed intervals, as well as half-open intervals. It supports various initialization methods and provides containment checks. 
Examples: >>> i1 = Interval(1, 5) # Default open interval (1, 5) >>> 3 in i1 True >>> 1 in i1 False >>> i2 = Interval([1, 5]) # Closed interval [1, 5] >>> 1 in i2 True >>> i3 = Interval(1, 5, closed_L=True) # Half-open interval [1, 5) >>> str(i3) '[1, 5)' >>> i4 = ClosedInterval(1, 5) # Closed interval [1, 5] >>> i5 = OpenInterval(1, 5) # Open interval (1, 5) """ def __init__( self, *args: Union[Sequence[Number], Number], is_closed: Optional[bool] = None, closed_L: Optional[bool] = None, closed_R: Optional[bool] = None, ): self.lower: Number self.upper: Number self.closed_L: bool self.closed_R: bool self.singleton_set: Optional[set[Number]] = None try: if len(args) == 0: ( self.lower, self.upper, self.closed_L, self.closed_R, self.singleton_set, ) = _EMPTY_INTERVAL_ARGS return # Handle different types of input arguments # Check for numeric types first to allow proper type narrowing if len(args) == 1 and isinstance(args[0], (int, float)): # a singleton, but this will be handled later self.lower = args[0] self.upper = args[0] default_closed = False elif len(args) == 1 and isinstance(args[0], (list, tuple)): assert len(args[0]) == 2, ( "if arg is a list or tuple, it must have length 2" ) self.lower = args[0][0] self.upper = args[0][1] # Determine closure type based on the container type default_closed = isinstance(args[0], list) elif len(args) == 2: self.lower, self.upper = args # type: ignore[assignment] default_closed = False # Default to open interval if two args else: raise ValueError(f"Invalid input arguments: {args}") # if both of the bounds are NaN or None, return an empty interval if any(x is None for x in (self.lower, self.upper)) or any( math.isnan(x) for x in (self.lower, self.upper) ): if (self.lower is None and self.upper is None) or ( math.isnan(self.lower) and math.isnan(self.upper) ): ( self.lower, self.upper, self.closed_L, self.closed_R, self.singleton_set, ) = _EMPTY_INTERVAL_ARGS return else: raise ValueError( "Both bounds must be NaN or None to create an empty interval. Also, just use `Interval.get_empty()` instead." 
) # Ensure lower bound is less than upper bound # TYPING: ty throws a @Todo here # Operator `>` is not supported for types `Sequence[@Todo]` and `Sequence[@Todo]`, in comparing `@Todo | Sequence[@Todo]` with `@Todo | Sequence[@Todo]`tyunsupported-operator if self.lower > self.upper: # type: ignore[unsupported-operator] raise ValueError("Lower bound must be less than upper bound") if math.isnan(self.lower) or math.isnan(self.upper): raise ValueError("NaN is not allowed as an interval bound") # Determine closure properties if is_closed is not None: # can't specify both is_closed and closed_L/R if (closed_L is not None) or (closed_R is not None): raise ValueError("Cannot specify both is_closed and closed_L/R") self.closed_L = is_closed self.closed_R = is_closed else: self.closed_L = closed_L if closed_L is not None else default_closed self.closed_R = closed_R if closed_R is not None else default_closed # handle singleton/empty case if self.lower == self.upper and not (self.closed_L or self.closed_R): ( self.lower, self.upper, self.closed_L, self.closed_R, self.singleton_set, ) = _EMPTY_INTERVAL_ARGS return elif self.lower == self.upper and (self.closed_L or self.closed_R): self.singleton_set = {self.lower} # Singleton interval self.closed_L = True self.closed_R = True return # otherwise `singleton_set` is `None` except (AssertionError, ValueError) as e: raise ValueError( f"Invalid input arguments to Interval: {args = }, {is_closed = }, {closed_L = }, {closed_R = }\n{e}\nUsage:\n{self.__doc__}" ) from e @property def is_closed(self) -> bool: if self.is_empty: return True if self.is_singleton: return True return self.closed_L and self.closed_R @property def is_open(self) -> bool: if self.is_empty: return True if self.is_singleton: return False return not self.closed_L and not self.closed_R @property def is_half_open(self) -> bool: return (self.closed_L and not self.closed_R) or ( not self.closed_L and self.closed_R ) @property def is_singleton(self) -> bool: return self.singleton_set is not None and len(self.singleton_set) == 1 @property def is_empty(self) -> bool: return self.singleton_set is not None and len(self.singleton_set) == 0 @property def is_finite(self) -> bool: return not math.isinf(self.lower) and not math.isinf(self.upper) @property def singleton(self) -> Number: if not self.is_singleton: raise ValueError("Interval is not a singleton") return next(iter(self.singleton_set)) # type: ignore[arg-type] @staticmethod def get_empty() -> Interval: return Interval(math.nan, math.nan, closed_L=None, closed_R=None) @staticmethod def get_singleton(value: Number) -> Interval: if math.isnan(value) or value is None: return Interval.get_empty() return Interval(value, value, closed_L=True, closed_R=True) def numerical_contained(self, item: Number) -> bool: if self.is_empty: return False if math.isnan(item): raise ValueError("NaN cannot be checked for containment in an interval") if self.is_singleton: return item in self.singleton_set # type: ignore[operator] return ((self.closed_L and item >= self.lower) or item > self.lower) and ( (self.closed_R and item <= self.upper) or item < self.upper ) def interval_contained(self, item: Interval) -> bool: if item.is_empty: return True if self.is_empty: return False if item.is_singleton: return self.numerical_contained(item.singleton) if self.is_singleton: if not item.is_singleton: return False return self.singleton == item.singleton lower_contained: bool = ( # either strictly wider bound self.lower < item.lower # if same, then self must be closed if item is 
open or (self.lower == item.lower and self.closed_L >= item.closed_L) ) upper_contained: bool = ( # either strictly wider bound self.upper > item.upper # if same, then self must be closed if item is open or (self.upper == item.upper and self.closed_R >= item.closed_R) ) return lower_contained and upper_contained def __contains__(self, item: Any) -> bool: if isinstance(item, Interval): return self.interval_contained(item) else: return self.numerical_contained(item) def __repr__(self) -> str: if self.is_empty: return r"∅" if self.is_singleton: return "{" + str(self.singleton) + "}" left: str = "[" if self.closed_L else "(" right: str = "]" if self.closed_R else ")" return f"{left}{self.lower}, {self.upper}{right}" def __str__(self) -> str: return repr(self) @classmethod def from_str(cls, input_str: str) -> Interval: input_str = input_str.strip() # empty and singleton if input_str.count(",") == 0: # empty set if input_str == "∅": return cls.get_empty() assert input_str.startswith("{") and input_str.endswith("}"), ( "Invalid input string" ) input_str_set_interior: str = input_str.strip("{}").strip() if len(input_str_set_interior) == 0: return cls.get_empty() # singleton set return cls.get_singleton(str_to_numeric(input_str_set_interior)) # expect commas if not input_str.count(",") == 1: raise ValueError("Invalid input string") # get bounds lower: str upper: str lower, upper = input_str.strip("[]()").split(",") lower = lower.strip() upper = upper.strip() lower_num: Number = str_to_numeric(lower) upper_num: Number = str_to_numeric(upper) # figure out closure closed_L: bool closed_R: bool if input_str[0] == "[": closed_L = True elif input_str[0] == "(": closed_L = False else: raise ValueError("Invalid input string") if input_str[-1] == "]": closed_R = True elif input_str[-1] == ")": closed_R = False else: raise ValueError("Invalid input string") return cls(lower_num, upper_num, closed_L=closed_L, closed_R=closed_R) def __eq__(self, other: object) -> bool: if not isinstance(other, Interval): return False if self.is_empty and other.is_empty: return True if self.is_singleton and other.is_singleton: return self.singleton == other.singleton return (self.lower, self.upper, self.closed_L, self.closed_R) == ( other.lower, other.upper, other.closed_L, other.closed_R, ) def __iter__(self): if self.is_empty: return elif self.is_singleton: yield self.singleton return else: yield self.lower yield self.upper def __getitem__(self, index: int) -> float: if self.is_empty: raise IndexError("Empty interval has no bounds") if self.is_singleton: if index == 0: return self.singleton else: raise IndexError("Singleton interval has only one bound") if index == 0: return self.lower elif index == 1: return self.upper else: raise IndexError("Interval index out of range") def __len__(self) -> int: return 0 if self.is_empty else 1 if self.is_singleton else 2 def copy(self) -> Interval: if self.is_empty: return Interval.get_empty() if self.is_singleton: return Interval.get_singleton(self.singleton) return Interval( self.lower, self.upper, closed_L=self.closed_L, closed_R=self.closed_R ) def size(self) -> float: """ Returns the size of the interval. # Returns: - `float` the size of the interval """ if self.is_empty or self.is_singleton: return 0 else: return self.upper - self.lower def clamp(self, value: Union[int, float], epsilon: float = _EPSILON) -> float: """ Clamp the given value to the interval bounds. For open bounds, the clamped value will be slightly inside the interval (by epsilon). 
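A quick sketch, assuming the module-level default `_EPSILON = 1e-10`:

>>> ClosedInterval(1, 5).clamp(10)
5
>>> Interval(1, 5).clamp(10) == 5 - 1e-10
True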
# Parameters: - `value : Union[int, float]` the value to clamp. - `epsilon : float` margin for open bounds (defaults to `_EPSILON`) # Returns: - `float` the clamped value # Raises: - `ValueError` : If the input value is NaN. """ if math.isnan(value): raise ValueError("Cannot clamp NaN value") if math.isnan(epsilon): raise ValueError("Epsilon cannot be NaN") if epsilon < 0: raise ValueError(f"Epsilon must be non-negative: {epsilon = }") if self.is_empty: raise ValueError("Cannot clamp to an empty interval") if self.is_singleton: return self.singleton if epsilon > self.size(): raise ValueError( f"epsilon is greater than the size of the interval: {epsilon = }, {self.size() = }, {self = }" ) # make type work with decimals and stuff if not isinstance(value, (int, float)): epsilon = value.__class__(epsilon) clamped_min: Number if self.closed_L: clamped_min = self.lower else: clamped_min = self.lower + epsilon clamped_max: Number if self.closed_R: clamped_max = self.upper else: clamped_max = self.upper - epsilon return max(clamped_min, min(value, clamped_max)) def intersection(self, other: Interval) -> Interval: if not isinstance(other, Interval): raise TypeError("Can only intersect with another Interval") if self.is_empty or other.is_empty: return Interval.get_empty() if self.is_singleton: if other.numerical_contained(self.singleton): return self.copy() else: return Interval.get_empty() if other.is_singleton: if self.numerical_contained(other.singleton): return other.copy() else: return Interval.get_empty() if self.upper < other.lower or other.upper < self.lower: return Interval.get_empty() lower: Number = max(self.lower, other.lower) upper: Number = min(self.upper, other.upper) closed_L: bool = self.closed_L if self.lower > other.lower else other.closed_L closed_R: bool = self.closed_R if self.upper < other.upper else other.closed_R return Interval(lower, upper, closed_L=closed_L, closed_R=closed_R) def union(self, other: Interval) -> Interval: if not isinstance(other, Interval): raise TypeError("Can only union with another Interval") # empty set case if self.is_empty: return other.copy() if other.is_empty: return self.copy() # special case where the intersection is empty but the intervals are contiguous if self.upper == other.lower: if self.closed_R or other.closed_L: return Interval( self.lower, other.upper, closed_L=self.closed_L, closed_R=other.closed_R, ) elif other.upper == self.lower: if other.closed_R or self.closed_L: return Interval( other.lower, self.upper, closed_L=other.closed_L, closed_R=self.closed_R, ) # non-intersecting nonempty and non-contiguous intervals if self.intersection(other) == Interval.get_empty(): raise NotImplementedError( "Union of non-intersecting nonempty non-contiguous intervals is not implemented " + f"{self = }, {other = }, {self.intersection(other) = }" ) # singleton case if self.is_singleton: return other.copy() if other.is_singleton: return self.copy() # regular case lower: Number = min(self.lower, other.lower) upper: Number = max(self.upper, other.upper) closed_L: bool = self.closed_L if self.lower < other.lower else other.closed_L closed_R: bool = self.closed_R if self.upper > other.upper else other.closed_R return Interval(lower, upper, closed_L=closed_L, closed_R=closed_R) class ClosedInterval(Interval): def __init__(self, *args: Union[Sequence[float], float], **kwargs: Any): if any(key in kwargs for key in ("is_closed", "closed_L", "closed_R")): raise ValueError("Cannot specify closure properties for ClosedInterval") super().__init__(*args, 
is_closed=True) class OpenInterval(Interval): def __init__(self, *args: Union[Sequence[float], float], **kwargs: Any): if any(key in kwargs for key in ("is_closed", "closed_L", "closed_R")): raise ValueError("Cannot specify closure properties for OpenInterval") super().__init__(*args, is_closed=False) ``````{ end_of_file="muutils/interval.py" } ``````{ path="muutils/jsonlines.py" } "utilities for reading and writing jsonlines files, including gzip support" from __future__ import annotations import gzip import json from typing import Callable, Sequence from muutils.json_serialize import JSONitem _GZIP_EXTENSIONS: tuple = (".gz", ".gzip") def _file_is_gzip(path: str) -> bool: return any(str(path).endswith(ext) for ext in _GZIP_EXTENSIONS) def _get_opener( path: str, use_gzip: bool | None = None, ) -> Callable: if use_gzip is None: use_gzip = _file_is_gzip(path) # appears to be another mypy bug # https://github.com/python/mypy/issues/10740 return open if not use_gzip else gzip.open # type: ignore def jsonl_load( path: str, /, *, use_gzip: bool | None = None, ) -> list[JSONitem]: opener: Callable = _get_opener(path, use_gzip) data: list[JSONitem] = list() with opener(path, "rt", encoding="UTF-8") as f: for line in f: data.append(json.loads(line)) return data def jsonl_load_log( path: str, /, *, use_gzip: bool | None = None, ) -> list[dict]: data: list[JSONitem] = jsonl_load(path, use_gzip=use_gzip) for idx, item in enumerate(data): assert isinstance(item, dict), ( f"item {idx = } from file {path} is not a dict: {type(item) = }\t{item = }" ) # mypy complains that we are returning a list[JSONitem] but the function signature says list[dict] # it can't figure out that we are asserting that all items are dicts return data # type: ignore def jsonl_write( path: str, items: Sequence[JSONitem], use_gzip: bool | None = None, gzip_compresslevel: int = 2, ) -> None: opener: Callable = _get_opener(path, use_gzip) opener_kwargs: dict = dict() if use_gzip: opener_kwargs = dict(compresslevel=gzip_compresslevel) with opener(path, "wt", encoding="UTF-8", **opener_kwargs) as f: for item in items: f.write(json.dumps(item) + "\n") ``````{ end_of_file="muutils/jsonlines.py" } ``````{ path="muutils/kappa.py" } """anonymous getitem class util for constructing a class which has a getitem method which just calls a function a `lambda` is an anonymous function: kappa is the letter before lambda in the greek alphabet, hence the name of this class""" from __future__ import annotations from typing import Callable, Final, Mapping, TypeVar _kappa_K = TypeVar("_kappa_K") _kappa_V = TypeVar("_kappa_V") # get the docstring of this file _BASE_DOC: Final[str] = ( # TYPING: type checkers complain here, they have no idea that this module does in fact have a __doc__ __doc__ or "anonymous getitem class" + """ source function docstring: ==============================\n """ ) class Kappa(Mapping[_kappa_K, _kappa_V]): def __init__(self, func_getitem: Callable[[_kappa_K], _kappa_V]) -> None: self.func_getitem = func_getitem self.doc = _BASE_DOC + str( getattr( func_getitem, "__doc__", "" ) ) def __getitem__(self, x: _kappa_K) -> _kappa_V: return self.func_getitem(x) def __iter__(self) -> None: # type: ignore[override] raise NotImplementedError( "This method is not implemented for Kappa, we don't know the valid inputs" ) def __len__(self) -> int: raise NotImplementedError( "This method is not implemented for Kappa, no idea how many valid inputs there are" ) ``````{ end_of_file="muutils/kappa.py" } ``````{ path="muutils/mlutils.py" } 
"miscellaneous utilities for ML pipelines" from __future__ import annotations import json import os import random import typing import warnings from itertools import islice from pathlib import Path from typing import Any, Callable, Generator, Iterable, Optional, TypeVar, Union ARRAY_IMPORTS: bool try: import numpy as np import torch import torch.backends.mps ARRAY_IMPORTS = True except ImportError as e: warnings.warn( f"Numpy or torch not installed. Array operations will not be available.\n{e}" ) ARRAY_IMPORTS = False DEFAULT_SEED: int = 42 GLOBAL_SEED: int = DEFAULT_SEED def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device": """Get the torch.device instance on which `torch.Tensor`s should be allocated.""" if not ARRAY_IMPORTS: raise ImportError( "Numpy or torch not installed. Array operations will not be available." ) try: # if device is given assert torch, "Torch is not available, cannot get device" # pyright: ignore[reportPossiblyUnboundVariable] if device is not None: device = torch.device(device) if any( [ torch.cuda.is_available() and device.type == "cuda", torch.backends.mps.is_available() and device.type == "mps", device.type == "cpu", ] ): # if device is given and available pass else: warnings.warn( f"Specified device {device} is not available, falling back to CPU" ) return torch.device("cpu") # no device given, infer from availability else: if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): device = torch.device("mps") else: device = torch.device("cpu") # put a dummy tensor on the device to check if it is available _dummy = torch.zeros(1, device=device) return device except Exception as e: warnings.warn( f"Error while getting device, falling back to CPU. Error: {e}", RuntimeWarning, ) return torch.device("cpu") # pyright: ignore[reportPossiblyUnboundVariable] def set_reproducibility(seed: int = DEFAULT_SEED): """ Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information. Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades off performance for reproducibility. Set use_deterministic_algorithms to True to improve performance. """ global GLOBAL_SEED GLOBAL_SEED = seed random.seed(seed) if ARRAY_IMPORTS: try: assert np, "Numpy is not available, cannot set seed for numpy" # pyright: ignore[reportPossiblyUnboundVariable] np.random.seed(seed) except Exception as e: warnings.warn(f"Error while setting seed for numpy: {e}", RuntimeWarning) try: assert torch, "Torch is not available, cannot set seed for torch" # pyright: ignore[reportPossiblyUnboundVariable] torch.manual_seed(seed) torch.use_deterministic_algorithms(True) except Exception as e: warnings.warn(f"Error while setting seed for torch: {e}", RuntimeWarning) # Ensure reproducibility for concurrent CUDA streams # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility. 
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" T = TypeVar("T") def chunks(it: Iterable[T], chunk_size: int) -> Generator[list[T], Any, None]: """Yield successive chunks from an iterator.""" # https://stackoverflow.com/a/61435714 iterator = iter(it) while chunk := list(islice(iterator, chunk_size)): yield chunk def get_checkpoint_paths_for_run( run_path: Path, extension: typing.Literal["pt", "zanj"], checkpoints_format: str = "checkpoints/model.iter_*.{extension}", ) -> list[tuple[int, Path]]: """get checkpoints of the format from the run_path note that `checkpoints_format` should contain a glob pattern with: - unresolved "{extension}" format term for the extension - a wildcard for the iteration number """ assert run_path.is_dir(), ( f"Model path {run_path} is not a directory (expect run directory, not model files)" ) return [ (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path) for checkpoint_path in sorted( Path(run_path).glob(checkpoints_format.format(extension=extension)) ) ] F = TypeVar("F", bound=Callable[..., Any]) def register_method( method_dict: dict[str, Callable[..., Any]], custom_name: Optional[str] = None, ) -> Callable[[F], F]: """Decorator to add a method to the method_dict""" def decorator(method: F) -> F: method_name: str if custom_name is None: method_name_orig: str | None = getattr(method, "__name__", None) if method_name_orig is None: warnings.warn( f"Method {method} does not have a name, using sanitized repr" ) from muutils.misc import sanitize_identifier method_name = sanitize_identifier(repr(method)) else: method_name = method_name_orig else: method_name = custom_name # TYPING: ty complains here method.__name__ = custom_name # type: ignore[unresolved-attribute] assert method_name not in method_dict, ( f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }" ) method_dict[method_name] = method return method return decorator def pprint_summary(summary: dict): print(json.dumps(summary, indent=2)) ``````{ end_of_file="muutils/mlutils.py" } ``````{ path="muutils/parallel.py" } "parallel processing utilities, chiefly `run_maybe_parallel`" from __future__ import annotations import multiprocessing import functools from typing import ( Any, Callable, Iterable, Literal, Optional, Tuple, TypeVar, Dict, List, Union, Protocol, ) # for no tqdm fallback from muutils.spinner import SpinnerContext from muutils.validate_type import get_fn_allowed_kwargs InputType = TypeVar("InputType") OutputType = TypeVar("OutputType") # typevars for our iterable and map class ProgressBarFunction(Protocol): "a protocol for a progress bar function" def __call__(self, iterable: Iterable[Any], **kwargs: Any) -> Iterable[Any]: ... 
ProgressBarOption = Literal["tqdm", "spinner", "none", None] # type for the progress bar option DEFAULT_PBAR_FN: ProgressBarOption # default progress bar function try: # use tqdm if it's available import tqdm DEFAULT_PBAR_FN = "tqdm" except ImportError: # use progress bar as fallback DEFAULT_PBAR_FN = "spinner" def spinner_fn_wrap(x: Iterable[Any], **kwargs: Any) -> List[Any]: "spinner wrapper" spinnercontext_allowed_kwargs: set[str] = get_fn_allowed_kwargs( SpinnerContext.__init__ ) mapped_kwargs: dict = { k: v for k, v in kwargs.items() if k in spinnercontext_allowed_kwargs } if "desc" in kwargs and "message" not in mapped_kwargs: mapped_kwargs["message"] = kwargs["desc"] if "message" not in mapped_kwargs and "total" in kwargs: mapped_kwargs["message"] = f"Processing {kwargs['total']} items" with SpinnerContext(**mapped_kwargs): output = list(x) return output def map_kwargs_for_tqdm(kwargs: Dict[str, Any]) -> Dict[str, Any]: "map kwargs for tqdm, cant wrap because the pbar dissapears?" tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__) # pyright: ignore[reportPossiblyUnboundVariable] mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs} if "desc" not in kwargs: if "message" in kwargs: mapped_kwargs["desc"] = kwargs["message"] elif "total" in kwargs: mapped_kwargs["desc"] = f"Processing {kwargs.get('total')} items" return mapped_kwargs def no_progress_fn_wrap(x: Iterable[Any], **kwargs: Any) -> Iterable[Any]: "fallback to no progress bar" return x def set_up_progress_bar_fn( pbar: Union[ProgressBarFunction, ProgressBarOption], pbar_kwargs: Optional[Dict[str, Any]] = None, **extra_kwargs: Any, ) -> Tuple[ProgressBarFunction, Dict[str, Any]]: """set up the progress bar function and its kwargs # Parameters: - `pbar : Union[ProgressBarFunction, ProgressBarOption]` progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use - `pbar_kwargs : Optional[Dict[str, Any]]` kwargs passed to the progress bar function (default to `None`) (defaults to `None`) # Returns: - `Tuple[ProgressBarFunction, dict]` a tuple of the progress bar function and its kwargs # Raises: - `ValueError` : if `pbar` is not one of the valid options """ pbar_fn: ProgressBarFunction if pbar_kwargs is None: pbar_kwargs = dict() pbar_kwargs = {**extra_kwargs, **pbar_kwargs} # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs` if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False): pbar_fn = no_progress_fn_wrap # type: ignore[assignment] # if `pbar` is a different string, figure out which progress bar to use elif isinstance(pbar, str): if pbar == "tqdm": pbar_fn = tqdm.tqdm # pyright: ignore[reportPossiblyUnboundVariable] pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs) elif pbar == "spinner": pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs) pbar_kwargs = dict() else: raise ValueError( f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }" ) else: # the default value is a callable which will resolve to tqdm if available or spinner as a fallback. 
we pass kwargs to this pbar_fn = pbar return pbar_fn, pbar_kwargs # TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes def run_maybe_parallel( func: Callable[[InputType], OutputType], iterable: Iterable[InputType], parallel: Union[bool, int], pbar_kwargs: Optional[Dict[str, Any]] = None, chunksize: Optional[int] = None, keep_ordered: bool = True, use_multiprocess: bool = False, pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN, ) -> List[OutputType]: """a function to make it easier to sometimes parallelize an operation - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)` - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel` the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())` # Parameters: - `func : Callable[[InputType], OutputType]` function passed to either `map` or `Pool.imap` - `iterable : Iterable[InputType]` iterable passed to either `map` or `Pool.imap` - `parallel : bool | int` whether to run in parallel, and how many processes to use - `pbar_kwargs : Dict[str, Any]` kwargs passed to the progress bar function # Returns: - `List[OutputType]` a list of the output of `func` for each element in `iterable` # Raises: - `ValueError` : if `parallel` is not a boolean or an integer greater than 1 - `ValueError` : if `use_multiprocess=True` and `parallel=False` - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available """ # number of inputs in iterable n_inputs: int = len(iterable) # type: ignore[arg-type] if n_inputs == 0: # Return immediately if there is no input return list() # which progress bar to use pbar_fn: ProgressBarFunction pbar_kwargs_processed: dict pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn( pbar=pbar, pbar_kwargs=pbar_kwargs, # extra kwargs total=n_inputs, ) # number of processes num_processes: int if isinstance(parallel, bool): num_processes = multiprocessing.cpu_count() if parallel else 1 elif isinstance(parallel, int): if parallel < 2: raise ValueError( f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }" ) num_processes = parallel else: raise ValueError( f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }" ) # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process num_processes = min(num_processes, n_inputs) mp = multiprocessing if num_processes == 1: parallel = False if use_multiprocess: if not parallel: raise ValueError("`use_multiprocess=True` requires `parallel=True`") try: import multiprocess # type: ignore[import-untyped] except ImportError as e: raise ImportError( "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. 
install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`" ) from e mp = multiprocess # set up the map function -- maybe its parallel, maybe it's just `map` do_map: Callable[ [Callable[[InputType], OutputType], Iterable[InputType]], Iterable[OutputType], ] if parallel: # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing` # TYPING: messy here pool = mp.Pool(num_processes) # type: ignore[possibly-missing-attribute] # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportUnknownVariableType] # use `imap` if we want to keep the order, otherwise use `imap_unordered` if keep_ordered: do_map = pool.imap else: do_map = pool.imap_unordered # figure out a smart chunksize if one is not given chunksize_int: int if chunksize is None: chunksize_int = max(1, n_inputs // num_processes) else: chunksize_int = chunksize # set the chunksize do_map = functools.partial(do_map, chunksize=chunksize_int) # type: ignore else: do_map = map # run the map function with a progress bar output: List[OutputType] = list( pbar_fn( do_map( func, iterable, ), **pbar_kwargs_processed, ) ) # close the pool if we used one if parallel: pool.close() # pyright: ignore[reportPossiblyUnboundVariable] pool.join() # pyright: ignore[reportPossiblyUnboundVariable] # return the output as a list return output ``````{ end_of_file="muutils/parallel.py" } ``````{ path="muutils/py.typed" } ``````{ end_of_file="muutils/py.typed" } ``````{ path="muutils/spinner.py" } """decorator `spinner_decorator` and context manager `SpinnerContext` to display a spinner using the base `Spinner` class while some code is running. """ from __future__ import annotations import os import time from dataclasses import dataclass, field import threading import sys from functools import wraps from types import TracebackType from typing import ( List, Dict, Callable, Any, Literal, Optional, TextIO, TypeVar, Sequence, Union, ContextManager, ) import warnings DecoratedFunction = TypeVar("DecoratedFunction", bound=Callable[..., Any]) "Define a generic type for the decorated function" @dataclass class SpinnerConfig: working: List[str] = field(default_factory=lambda: ["|", "/", "-", "\\"]) success: str = "✔️" fail: str = "❌" def is_ascii(self) -> bool: "whether all characters are ascii" return all(s.isascii() for s in self.working + [self.success, self.fail]) def eq_lens(self) -> bool: "whether all working characters are the same length" expected_len: int = len(self.working[0]) return all( [ len(char) == expected_len for char in self.working + [self.success, self.fail] ] ) def is_valid(self) -> bool: "whether the spinner config is valid" return all( [ len(self.working) > 0, isinstance(self.working, list), isinstance(self.success, str), isinstance(self.fail, str), all(isinstance(char, str) for char in self.working), ] ) def __post_init__(self): if not self.is_valid(): raise ValueError(f"Invalid SpinnerConfig: {self}") @classmethod def from_any(cls, arg: "SpinnerConfigArg") -> "SpinnerConfig": # check SpinnerConfig first to help type narrowing if isinstance(arg, SpinnerConfig): return arg elif isinstance(arg, str): return SPINNERS[arg] elif isinstance(arg, list): return SpinnerConfig(working=arg) elif isinstance(arg, dict): return SpinnerConfig(**arg) else: raise TypeError( f"to create a SpinnerConfig, you must pass a string (key), list (working seq), dict (kwargs to SpinnerConfig), or SpinnerConfig, but got {type(arg) = }, {arg = }" ) SpinnerConfigArg = Union[str, List[str], 
SpinnerConfig, Dict[str, Any]] SPINNERS: Dict[str, SpinnerConfig] = dict( default=SpinnerConfig(working=["|", "/", "-", "\\"], success="#", fail="X"), dots=SpinnerConfig(working=[". ", ".. ", "..."], success="***", fail="xxx"), bars=SpinnerConfig(working=["| ", "|| ", "|||"], success="|||", fail="///"), arrows=SpinnerConfig(working=["<", "^", ">", "v"], success="►", fail="✖"), arrows_2=SpinnerConfig( working=["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"], success="→", fail="↯" ), bouncing_bar=SpinnerConfig( working=["[ ]", "[= ]", "[== ]", "[=== ]", "[ ===]", "[ ==]", "[ =]"], success="[====]", fail="[XXXX]", ), bar=SpinnerConfig( working=["[ ]", "[- ]", "[--]", "[ -]"], success="[==]", fail="[xx]", ), bouncing_ball=SpinnerConfig( working=[ "( ● )", "( ● )", "( ● )", "( ● )", "( ●)", "( ● )", "( ● )", "( ● )", "( ● )", "(● )", ], success="(●●●●●●)", fail="( ✖ )", ), ooo=SpinnerConfig(working=[".", "o", "O", "o"], success="O", fail="x"), braille=SpinnerConfig( working=["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"], success="⣿", fail="X", ), clock=SpinnerConfig( working=[ "🕛", "🕐", "🕑", "🕒", "🕓", "🕔", "🕕", "🕖", "🕗", "🕘", "🕙", "🕚", ], success="✔️", fail="❌", ), hourglass=SpinnerConfig(working=["⏳", "⌛"], success="✔️", fail="❌"), square_corners=SpinnerConfig(working=["◰", "◳", "◲", "◱"], success="◼", fail="✖"), triangle=SpinnerConfig(working=["◢", "◣", "◤", "◥"], success="◆", fail="✖"), square_dot=SpinnerConfig( working=["⣷", "⣯", "⣟", "⡿", "⢿", "⣻", "⣽", "⣾"], success="⣿", fail="❌" ), box_bounce=SpinnerConfig(working=["▌", "▀", "▐", "▄"], success="■", fail="✖"), hamburger=SpinnerConfig(working=["☱", "☲", "☴"], success="☰", fail="✖"), earth=SpinnerConfig(working=["🌍", "🌎", "🌏"], success="✔️", fail="❌"), growing_dots=SpinnerConfig( working=["⣀", "⣄", "⣤", "⣦", "⣶", "⣷", "⣿"], success="⣿", fail="✖" ), dice=SpinnerConfig(working=["⚀", "⚁", "⚂", "⚃", "⚄", "⚅"], success="🎲", fail="✖"), wifi=SpinnerConfig( working=["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"], success="✔️", fail="❌" ), bounce=SpinnerConfig(working=["⠁", "⠂", "⠄", "⠂"], success="⠿", fail="⢿"), arc=SpinnerConfig(working=["◜", "◠", "◝", "◞", "◡", "◟"], success="○", fail="✖"), toggle=SpinnerConfig(working=["⊶", "⊷"], success="⊷", fail="⊗"), toggle2=SpinnerConfig(working=["▫", "▪"], success="▪", fail="✖"), toggle3=SpinnerConfig(working=["□", "■"], success="■", fail="✖"), toggle4=SpinnerConfig(working=["■", "□", "▪", "▫"], success="■", fail="✖"), toggle5=SpinnerConfig(working=["▮", "▯"], success="▮", fail="✖"), toggle7=SpinnerConfig(working=["⦾", "⦿"], success="⦿", fail="✖"), toggle8=SpinnerConfig(working=["◍", "◌"], success="◍", fail="✖"), toggle9=SpinnerConfig(working=["◉", "◎"], success="◉", fail="✖"), arrow2=SpinnerConfig( working=["⬆️ ", "↗️ ", "➡️ ", "↘️ ", "⬇️ ", "↙️ ", "⬅️ ", "↖️ "], success="➡️", fail="❌" ), point=SpinnerConfig( working=["∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙"], success="●●●", fail="xxx" ), layer=SpinnerConfig(working=["-", "=", "≡"], success="≡", fail="✖"), speaker=SpinnerConfig( working=["🔈 ", "🔉 ", "🔊 ", "🔉 "], success="🔊", fail="🔇" ), orangePulse=SpinnerConfig( working=["🔸 ", "🔶 ", "🟠 ", "🟠 ", "🔷 "], success="🟠", fail="❌" ), bluePulse=SpinnerConfig( working=["🔹 ", "🔷 ", "🔵 ", "🔵 ", "🔷 "], success="🔵", fail="❌" ), satellite_signal=SpinnerConfig( working=["📡 ", "📡· ", "📡·· ", "📡···", "📡 ··", "📡 ·"], success="📡 ✔️ ", fail="📡 ❌ ", ), rocket_orbit=SpinnerConfig( working=["🌍🚀 ", "🌏 🚀 ", "🌎 🚀"], success="🌍 ✨", fail="🌍 💥" ), ogham=SpinnerConfig(working=["ᚁ ", "ᚂ ", "ᚃ ", "ᚄ", "ᚅ"], success="᚛᚜", fail="✖"), eth=SpinnerConfig( 
working=["᛫", "፡", "፥", "፤", "፧", "።", "፨"], success="፠", fail="✖" ), ) # spinner configurations class Spinner: """displays a spinner, and optionally elapsed time and a mutable value while a function is running. # Parameters: - `update_interval : float` how often to update the spinner display in seconds (defaults to `0.1`) - `initial_value : str` initial value to display with the spinner (defaults to `""`) - `message : str` message to display with the spinner (defaults to `""`) - `format_string : str` string to format the spinner with. must have `"\\r"` prepended to clear the line. allowed keys are `spinner`, `elapsed_time`, `message`, and `value` (defaults to `"\\r{spinner} ({elapsed_time:.2f}s) {message}{value}"`) - `output_stream : TextIO` stream to write the spinner to (defaults to `sys.stdout`) - `format_string_when_updated : Union[bool,str]` whether to use a different format string when the value is updated. if `True`, use the default format string with a newline appended. if a string, use that string. this is useful if you want update_value to print to console and be preserved. (defaults to `False`) # Deprecated Parameters: - `spinner_chars : Union[str, Sequence[str]]` sequence of strings, or key to look up in `SPINNER_CHARS`, to use as the spinner characters (defaults to `"default"`) - `spinner_complete : str` string to display when the spinner is complete (defaults to looking up `spinner_chars` in `SPINNER_COMPLETE` or `"#"`) # Methods: - `update_value(value: Any) -> None` update the current value displayed by the spinner # Usage: ## As a context manager: ```python with SpinnerContext() as sp: for i in range(1): time.sleep(0.1) spinner.update_value(f"Step {i+1}") ``` ## As a decorator: ```python @spinner_decorator def long_running_function(): for i in range(1): time.sleep(0.1) spinner.update_value(f"Step {i+1}") return "Function completed" ``` """ def __init__( self, # no positional args *args: Any, config: SpinnerConfigArg = "default", update_interval: float = 0.1, initial_value: str = "", message: str = "", format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}", output_stream: TextIO = sys.stdout, format_string_when_updated: Union[str, bool] = False, # deprecated spinner_chars: Optional[Union[str, Sequence[str]]] = None, spinner_complete: Optional[str] = None, # no other kwargs accepted **kwargs: Any, ): if args: raise ValueError(f"Spinner does not accept positional arguments: {args}") if kwargs: raise ValueError( f"Spinner did not recognize these keyword arguments: {kwargs}" ) # old spinner display if (spinner_chars is not None) or (spinner_complete is not None): warnings.warn( "spinner_chars and spinner_complete are deprecated and will have no effect. 
Use `config` instead.", DeprecationWarning, ) # config self.config: SpinnerConfig = SpinnerConfig.from_any(config) # special format string for when the value is updated self.format_string_when_updated: Optional[str] = None "format string to use when the value is updated" if format_string_when_updated is not False: if format_string_when_updated is True: # modify the default format string self.format_string_when_updated = format_string + "\n" elif isinstance(format_string_when_updated, str): # use the provided format string self.format_string_when_updated = format_string_when_updated else: raise TypeError( "format_string_when_updated must be a string or True, got" + f" {type(format_string_when_updated) = }{format_string_when_updated}" ) # copy other kwargs self.update_interval: float = update_interval self.message: str = message self.current_value: Any = initial_value self.format_string: str = format_string self.output_stream: TextIO = output_stream # test out format string try: self.format_string.format( spinner=self.config.working[0], elapsed_time=0.0, message=self.message, value=self.current_value, ) except Exception as e: raise ValueError( f"Invalid format string: {format_string}. Must take keys " + "'spinner: str', 'elapsed_time: float', 'message: str', and 'value: Any'." ) from e # init self.start_time: float = 0 "for measuring elapsed time" self.stop_spinner: threading.Event = threading.Event() "to stop the spinner" self.spinner_thread: Optional[threading.Thread] = None "the thread running the spinner" self.value_changed: bool = False "whether the value has been updated since the last display" self.term_width: int "width of the terminal, for padding with spaces" try: self.term_width = os.get_terminal_size().columns except OSError: self.term_width = 80 # state of the spinner self.state: Literal["initialized", "running", "success", "fail"] = "initialized" def spin(self) -> None: "Function to run in a separate thread, displaying the spinner and optional information" i: int = 0 while not self.stop_spinner.is_set(): # get current spinner str spinner: str = self.config.working[i % len(self.config.working)] # args for display string display_parts: Dict[str, Any] = dict( spinner=spinner, # str elapsed_time=time.time() - self.start_time, # float message=self.message, # str value=self.current_value, # Any, but will be formatted as str ) # use the special one if needed format_str: str = self.format_string if self.value_changed and (self.format_string_when_updated is not None): self.value_changed = False format_str = self.format_string_when_updated # write and flush the display string output: str = format_str.format(**display_parts).ljust(self.term_width) self.output_stream.write(output) self.output_stream.flush() # wait for the next update time.sleep(self.update_interval) i += 1 def update_value(self, value: Any) -> None: "Update the current value displayed by the spinner" self.current_value = value self.value_changed = True def start(self) -> None: "Start the spinner" self.start_time = time.time() self.spinner_thread = threading.Thread(target=self.spin) self.spinner_thread.start() self.state = "running" def stop(self, failed: bool = False) -> None: "Stop the spinner" self.output_stream.write( self.format_string.format( spinner=self.config.success if not failed else self.config.fail, elapsed_time=time.time() - self.start_time, # float message=self.message, # str value=self.current_value, # Any, but will be formatted as str ).ljust(self.term_width) ) self.stop_spinner.set() if 
self.spinner_thread: self.spinner_thread.join() self.output_stream.write("\n") self.output_stream.flush() self.state = "fail" if failed else "success" class NoOpContextManager(ContextManager): # type: ignore[type-arg] """A context manager that does nothing.""" def __init__(self, *args: Any, **kwargs: Any) -> None: pass def __enter__(self) -> NoOpContextManager: return self def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: pass class SpinnerContext(Spinner, ContextManager): "see `Spinner` for parameters" def __enter__(self) -> "SpinnerContext": self.start() return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: self.stop(failed=exc_type is not None) SpinnerContext.__doc__ = Spinner.__doc__ # TODO: type hint that the `update_status` kwarg is not needed when calling the function we just decorated def spinner_decorator( *args: Any, # passed to `Spinner.__init__` config: SpinnerConfigArg = "default", update_interval: float = 0.1, initial_value: str = "", message: str = "", format_string: str = "{spinner} ({elapsed_time:.2f}s) {message}{value}", output_stream: TextIO = sys.stdout, # new kwarg mutable_kwarg_key: Optional[str] = None, # deprecated spinner_chars: Union[str, Sequence[str], None] = None, spinner_complete: Optional[str] = None, **kwargs: Any, ) -> Callable[[DecoratedFunction], DecoratedFunction]: """see `Spinner` for parameters. Also takes `mutable_kwarg_key` `mutable_kwarg_key` is the key with which `Spinner().update_value` will be passed to the decorated function. if `None`, won't pass it. """ if len(args) > 1: raise ValueError( f"spinner_decorator does not accept positional arguments: {args}" ) if kwargs: raise ValueError( f"spinner_decorator did not recognize these keyword arguments: {kwargs}" ) def decorator(func: DecoratedFunction) -> DecoratedFunction: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: spinner: Spinner = Spinner( config=config, update_interval=update_interval, initial_value=initial_value, message=message, format_string=format_string, output_stream=output_stream, spinner_chars=spinner_chars, spinner_complete=spinner_complete, ) if mutable_kwarg_key: kwargs[mutable_kwarg_key] = spinner.update_value spinner.start() try: result: Any = func(*args, **kwargs) spinner.stop(failed=False) except Exception as e: spinner.stop(failed=True) raise e return result # TODO: fix this type ignore return wrapper # type: ignore[return-value] if not args: # called as `@spinner_decorator(stuff)` return decorator else: # called as `@spinner_decorator` without parens return decorator(args[0]) spinner_decorator.__doc__ = Spinner.__doc__ ``````{ end_of_file="muutils/spinner.py" } ``````{ path="muutils/statcounter.py" } """`StatCounter` class for counting and calculating statistics on numbers cleaner and more efficient than just using a `Counter` or array""" from __future__ import annotations import json import math from collections import Counter from functools import cached_property from itertools import chain from typing import Any, Callable, Optional, Sequence, Union # _GeneralArray = Union[np.ndarray, "torch.Tensor"] NumericSequence = Sequence[Union[float, int, "NumericSequence"]] # pylint: disable=abstract-method # misc # ================================================== def universal_flatten( arr: Union[NumericSequence, float, int], require_rectangular: bool = True ) -> NumericSequence: """flattens any 
iterable""" # mypy complains that the sequence has no attribute "flatten" if hasattr(arr, "flatten") and callable(arr.flatten): # type: ignore return arr.flatten() # type: ignore elif isinstance(arr, Sequence): elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr] if require_rectangular and (all(elements_iterable) != any(elements_iterable)): raise ValueError("arr contains mixed iterable and non-iterable elements") if any(elements_iterable): return list(chain.from_iterable(universal_flatten(x) for x in arr)) # type: ignore[misc] else: return arr else: return [arr] # StatCounter # ================================================== class StatCounter(Counter): """`Counter`, but with some stat calculation methods which assume the keys are numerical works best when the keys are `int`s """ def validate(self) -> bool: """validate the counter as being all floats or ints""" return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys()) def min(self): "minimum value" return min(x for x, v in self.items() if v > 0) def max(self): "maximum value" return max(x for x, v in self.items() if v > 0) def total(self): """Sum of the counts""" return sum(self.values()) @cached_property def keys_sorted(self) -> list: """return the keys""" return sorted(list(self.keys())) def percentile(self, p: float): """return the value at the given percentile this could be log time if we did binary search, but that would be a lot of added complexity """ if p < 0 or p > 1: raise ValueError(f"percentile must be between 0 and 1: {p}") # flip for speed sorted_keys: list[float] = [float(x) for x in self.keys_sorted] sort: int = 1 if p > 0.51: sort = -1 p = 1 - p sorted_keys = sorted_keys[::sort] real_target: float = p * (self.total() - 1) n_target_f: int = math.floor(real_target) n_target_c: int = math.ceil(real_target) n_sofar: float = -1 # print(f'{p = } {real_target = } {n_target_f = } {n_target_c = }') for i, k in enumerate(sorted_keys): n_sofar += self[k] # print(f'{k = } {n_sofar = }') if n_sofar > n_target_f: return k elif n_sofar == n_target_f: if n_sofar == n_target_c: return k else: # print( # sorted_keys[i], (n_sofar + 1 - real_target), # sorted_keys[i + 1], (real_target - n_sofar), # ) return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[ i + 1 ] * (real_target - n_sofar) else: continue raise ValueError(f"percentile {p} not found???") def median(self) -> float: return self.percentile(0.5) def mean(self) -> float: """return the mean of the values""" return float(sum(k * c for k, c in self.items()) / self.total()) def mode(self) -> float: return self.most_common()[0][0] def std(self) -> float: """return the standard deviation of the values""" mean: float = self.mean() deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items()) return (deviations / self.total()) ** 0.5 def summary( self, typecast: Callable = lambda x: x, *, extra_percentiles: Optional[list[float]] = None, ) -> dict[str, Union[float, int]]: """return a summary of the stats, without the raw data. 
human readable and small""" # common stats that always work output: dict = dict( total_items=self.total(), n_keys=len(self.keys()), mode=self.mode(), ) if self.total() > 0: if self.validate(): # if its a numeric counter, we can do some stats output = { **output, **dict( mean=float(self.mean()), std=float(self.std()), min=typecast(self.min()), q1=typecast(self.percentile(0.25)), median=typecast(self.median()), q3=typecast(self.percentile(0.75)), max=typecast(self.max()), ), } if extra_percentiles is not None: for p in extra_percentiles: output[f"percentile_{p}"] = typecast(self.percentile(p)) else: # if its not, we can only do the simpler things # mean mode and total are done in the initial declaration of `output` pass return output def serialize( self, typecast: Callable = lambda x: x, *, extra_percentiles: Optional[list[float]] = None, ) -> dict: """return a json-serializable version of the counter includes both the output of `summary` and the raw data: ```json { "StatCounter": { }, "summary": self.summary(typecast, extra_percentiles=extra_percentiles), } """ return { "StatCounter": { typecast(k): v for k, v in sorted(dict(self).items(), key=lambda x: x[0]) }, "summary": self.summary(typecast, extra_percentiles=extra_percentiles), } def __str__(self) -> str: "summary as json with 2 space indent, good for printing" return json.dumps(self.summary(), indent=2) def __repr__(self) -> str: return json.dumps(self.serialize(), indent=2) @classmethod def load(cls, data: dict) -> "StatCounter": "load from a the output of `StatCounter.serialize`" if "StatCounter" in data: loadme = data["StatCounter"] else: loadme = data return cls({float(k): v for k, v in loadme.items()}) @classmethod def from_list_arrays( cls, arr: Any, map_func: Callable[[Any], float] = float, ) -> "StatCounter": """calls `map_func` on each element of `universal_flatten(arr)`""" return cls([map_func(x) for x in universal_flatten(arr)]) ``````{ end_of_file="muutils/statcounter.py" } ``````{ path="muutils/sysinfo.py" } "utilities for getting information about the system, see `SysInfo` class" from __future__ import annotations import subprocess import sys import typing from importlib.metadata import distributions def _popen(cmd: list[str], split_out: bool = False) -> dict[str, typing.Any]: p: subprocess.Popen[bytes] = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) stdout, stderr = p.communicate() p_out: typing.Union[str, list[str], None] if stdout: p_out = stdout.decode("utf-8") if split_out: assert isinstance(p_out, str) p_out = p_out.strip().split("\n") else: p_out = None return { "stdout": p_out, "stderr": stderr.decode("utf-8") if stderr else None, "returncode": p.returncode if p.returncode is None else int(p.returncode), } class SysInfo: """getters for various information about the system""" @staticmethod def python() -> dict: """details about python version""" ver_tup = sys.version_info return { "version": sys.version, "version_info": ver_tup, "major": ver_tup[0], "minor": ver_tup[1], "micro": ver_tup[2], "releaselevel": ver_tup[3], "serial": ver_tup[4], } @staticmethod def pip() -> dict: """installed packages info""" # in python <= 3.9 `Distribution` has no attribute `name` pckgs: list[tuple[str, str]] = [ ( ( x.metadata.get("Name", "") # type: ignore[attr-defined] if sys.version_info < (3, 10) else x.name # type: ignore[attr-defined] ), x.version, ) for x in distributions() ] return { "n_packages": len(pckgs), "packages": pckgs, } @staticmethod def pytorch() -> dict: """pytorch and cuda information""" 
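# NOTE: torch is imported lazily inside the try/except below, so this getter
# degrades gracefully, reporting {"importable": False, "error": ...} rather
# than raising, when torch is not installed.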
try: import torch import torch.version except Exception as e: return { "importable": False, "error": str(e), } output: dict = {"importable": True} output["torch.__version__"] = torch.__version__ output["torch.version.cuda"] = torch.version.cuda output["torch.version.debug"] = torch.version.debug output["torch.version.git_version"] = torch.version.git_version output["torch.version.hip"] = torch.version.hip output["torch.cuda.is_available()"] = torch.cuda.is_available() output["torch.cuda.device_count()"] = torch.cuda.device_count() output["torch.cuda.is_initialized()"] = torch.cuda.is_initialized() if torch.cuda.is_available(): import os cuda_version_nvcc: str = os.popen("nvcc --version").read() output["nvcc --version"] = cuda_version_nvcc.split("\n") if torch.cuda.device_count() > 0: n_devices: int = torch.cuda.device_count() output["torch.cuda.current_device()"] = torch.cuda.current_device() output["torch devices"] = [] for current_device in range(n_devices): try: # print(f'checking current device {current_device} of {torch.cuda.device_count()} devices') # print(f'\tdevice {current_device}') # dev_prop = torch.cuda.get_device_properties(torch.device(0)) # print(f'\t name: {dev_prop.name}') # print(f'\t version: {dev_prop.major}.{dev_prop.minor}') # print(f'\t total_memory: {dev_prop.total_memory}') # print(f'\t multi_processor_count: {dev_prop.multi_processor_count}') # print(f'\t') dev_prop = torch.cuda.get_device_properties(current_device) output["torch devices"].append( { "device": current_device, "name": dev_prop.name, "version": { "major": dev_prop.major, "minor": dev_prop.minor, }, "total_memory": dev_prop.total_memory, "multi_processor_count": dev_prop.multi_processor_count, } ) except Exception as e: output["torch devices"].append( { "device": current_device, "error": str(e), } ) return output @staticmethod def platform() -> dict: import platform items = [ "platform", "machine", "processor", "system", "version", "architecture", "uname", "node", "python_branch", "python_build", "python_compiler", "python_implementation", ] return {x: getattr(platform, x)() for x in items} @staticmethod def git_info(with_log: bool = False) -> dict: git_version: dict = _popen(["git", "version"]) git_status: dict = _popen(["git", "status"]) if not git_status["stderr"] or git_status["stderr"].startswith( "fatal: not a git repository" ): return { "git version": git_version["stdout"], "git status": git_status, } else: output: dict = { "git version": git_version["stdout"], "git status": git_status, "git branch": _popen(["git", "branch"], split_out=True), "git remote -v": _popen(["git", "remote", "-v"], split_out=True), } if with_log: output["git log"] = _popen(["git", "log"], split_out=False) return output @classmethod def get_all( cls, include: typing.Optional[tuple[str, ...]] = None, exclude: tuple[str, ...] = tuple(), ) -> dict: include_meta: tuple[str, ...] 
if include is None: include_meta = tuple(cls.__dict__.keys()) else: include_meta = include return { x: getattr(cls, x)() for x in include_meta if all( [ not x.startswith("_"), x not in exclude, callable(getattr(cls, x)), x != "get_all", x in include if include is not None else True, ] ) } if __name__ == "__main__": import pprint pprint.pprint(SysInfo.get_all()) ``````{ end_of_file="muutils/sysinfo.py" } ``````{ path="muutils/tensor_info.py" } "get metadata about a tensor, mostly for `muutils.dbg`" from __future__ import annotations import numpy as np from typing import Union, Any, Literal, List, Dict, overload, Optional, TYPE_CHECKING if TYPE_CHECKING: from typing import TypedDict else: try: from typing import TypedDict except ImportError: from typing_extensions import TypedDict # Global color definitions COLORS: Dict[str, Dict[str, str]] = { "latex": { "range": r"\textcolor{purple}", "mean": r"\textcolor{teal}", "std": r"\textcolor{orange}", "median": r"\textcolor{green}", "warning": r"\textcolor{red}", "shape": r"\textcolor{magenta}", "dtype": r"\textcolor{gray}", "device": r"\textcolor{gray}", "requires_grad": r"\textcolor{gray}", "sparkline": r"\textcolor{blue}", "torch": r"\textcolor{orange}", "dtype_bool": r"\textcolor{gray}", "dtype_int": r"\textcolor{blue}", "dtype_float": r"\textcolor{red!70}", # 70% red intensity "dtype_str": r"\textcolor{red}", "device_cuda": r"\textcolor{green}", "reset": "", }, "terminal": { "range": "\033[35m", # purple "mean": "\033[36m", # cyan/teal "std": "\033[33m", # yellow/orange "median": "\033[32m", # green "warning": "\033[31m", # red "shape": "\033[95m", # bright magenta "dtype": "\033[90m", # gray "device": "\033[90m", # gray "requires_grad": "\033[90m", # gray "sparkline": "\033[34m", # blue "torch": "\033[38;5;208m", # bright orange "dtype_bool": "\033[38;5;245m", # medium grey "dtype_int": "\033[38;5;39m", # bright blue "dtype_float": "\033[38;5;167m", # softer red/coral "device_cuda": "\033[38;5;76m", # NVIDIA-style bright green "reset": "\033[0m", }, "none": { "range": "", "mean": "", "std": "", "median": "", "warning": "", "shape": "", "dtype": "", "device": "", "requires_grad": "", "sparkline": "", "torch": "", "dtype_bool": "", "dtype_int": "", "dtype_float": "", "dtype_str": "", "device_cuda": "", "reset": "", }, } OutputFormat = Literal["unicode", "latex", "ascii"] class ArraySummarySettings(TypedDict): """Type definition for array_summary default settings.""" fmt: OutputFormat precision: int stats: bool shape: bool dtype: bool device: bool requires_grad: bool sparkline: bool sparkline_bins: int sparkline_logy: Optional[bool] colored: bool as_list: bool eq_char: str SYMBOLS: Dict[OutputFormat, Dict[str, str]] = { "latex": { "range": r"\mathcal{R}", "mean": r"\mu", "std": r"\sigma", "median": r"\tilde{x}", "distribution": r"\mathbb{P}", "distribution_log": r"\mathbb{P}_L", "nan_values": r"\text{NANvals}", "warning": "!!!", "requires_grad": r"\nabla", "true": r"\checkmark", "false": r"\times", }, "unicode": { "range": "R", "mean": "μ", "std": "σ", "median": "x̃", "distribution": "ℙ", "distribution_log": "ℙ˪", "nan_values": "NANvals", "warning": "🚨", "requires_grad": "∇", "true": "✓", "false": "✗", }, "ascii": { "range": "range", "mean": "mean", "std": "std", "median": "med", "distribution": "dist", "distribution_log": "dist_log", "nan_values": "NANvals", "warning": "!!!", "requires_grad": "requires_grad", "true": "1", "false": "0", }, } "Symbols for different formats" SPARK_CHARS: Dict[OutputFormat, List[str]] = { "unicode": list(" 
▁▂▃▄▅▆▇█"), "ascii": list(" _.-~=#"), "latex": list(" ▁▂▃▄▅▆▇█"), } "characters for sparklines in different formats" def array_info( A: Any, hist_bins: int = 5, ) -> Dict[str, Any]: """Extract statistical information from an array-like object. # Parameters: - `A : array-like` Array to analyze (numpy array or torch tensor) # Returns: - `Dict[str, Any]` Dictionary containing raw statistical information with numeric values """ result: Dict[str, Any] = { "is_tensor": None, "device": None, "requires_grad": None, "shape": None, "dtype": None, "size": None, "has_nans": None, "nan_count": None, "nan_percent": None, "min": None, "max": None, "range": None, "mean": None, "std": None, "median": None, "histogram": None, "bins": None, "status": None, } # Check if it's a tensor by looking at its class name # This avoids importing torch directly A_type: str = type(A).__name__ result["is_tensor"] = A_type == "Tensor" # Try to get device information if it's a tensor if result["is_tensor"]: try: result["device"] = str(getattr(A, "device", None)) except: # noqa: E722 pass # Convert to numpy array for calculations try: # For PyTorch tensors if result["is_tensor"]: # Check if tensor is on GPU is_cuda: bool = False try: is_cuda = bool(getattr(A, "is_cuda", False)) except: # noqa: E722 pass if is_cuda: try: # Try to get CPU tensor first cpu_tensor = getattr(A, "cpu", lambda: A)() except: # noqa: E722 A_np = np.array([]) else: cpu_tensor = A try: # For CPU tensor, just detach and convert detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)() # pyright: ignore[reportPossiblyUnboundVariable] A_np = getattr(detached, "numpy", lambda: np.array([]))() except: # noqa: E722 A_np = np.array([]) else: # For numpy arrays and other array-like objects A_np = np.asarray(A) except: # noqa: E722 A_np = np.array([]) # Get basic information try: result["shape"] = A_np.shape result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype) result["size"] = A_np.size result["requires_grad"] = getattr(A, "requires_grad", None) except: # noqa: E722 pass # If array is empty, return early if result["size"] == 0: result["status"] = "empty array" return result # Flatten array for statistics if it's multi-dimensional # TODO: type checks fail on 3.10, see https://github.com/mivanit/muutils/actions/runs/18883100459/job/53891346225 try: if len(A_np.shape) > 1: A_flat = A_np.flatten() # type: ignore[assignment] else: A_flat = A_np # type: ignore[assignment] except: # noqa: E722 A_flat = A_np # type: ignore[assignment] # Check for NaN values try: nan_mask = np.isnan(A_flat) result["nan_count"] = np.sum(nan_mask) result["has_nans"] = result["nan_count"] > 0 result_size: int = result["size"] # ty: ignore[invalid-assignment] if result_size > 0: result["nan_percent"] = (result["nan_count"] / result_size) * 100 except: # noqa: E722 pass # If all values are NaN, return early if result["has_nans"] and result["nan_count"] == result["size"]: result["status"] = "all NaN" return result # Calculate statistics try: if result["has_nans"]: result["min"] = float(np.nanmin(A_flat)) result["max"] = float(np.nanmax(A_flat)) result["mean"] = float(np.nanmean(A_flat)) result["std"] = float(np.nanstd(A_flat)) result["median"] = float(np.nanmedian(A_flat)) result["range"] = (result["min"], result["max"]) # Remove NaNs for histogram # TYPING: nan mask will def be bound on this branch, idk why it thinks the operator is bad A_hist = A_flat[~nan_mask] # pyright: ignore[reportOperatorIssue, reportPossiblyUnboundVariable] else: result["min"] = 
float(np.min(A_flat)) result["max"] = float(np.max(A_flat)) result["mean"] = float(np.mean(A_flat)) result["std"] = float(np.std(A_flat)) result["median"] = float(np.median(A_flat)) result["range"] = (result["min"], result["max"]) A_hist = A_flat # Calculate histogram data for sparklines if A_hist.size > 0: try: # TODO: handle bool tensors correctly # muutils/tensor_info.py:238: RuntimeWarning: Converting input from bool to for compatibility. hist, bins = np.histogram(A_hist, bins=hist_bins) result["histogram"] = hist result["bins"] = bins except: # noqa: E722 pass result["status"] = "ok" except Exception as e: result["status"] = f"error: {str(e)}" return result SparklineFormat = Literal["unicode", "latex", "ascii"] def generate_sparkline( histogram: np.ndarray, format: SparklineFormat = "unicode", log_y: Optional[bool] = None, ) -> tuple[str, bool]: """Generate a sparkline visualization of the histogram. # Parameters: - `histogram : np.ndarray` Histogram data - `format : Literal["unicode", "latex", "ascii"]` Output format (defaults to `"unicode"`) - `log_y : bool|None` Whether to use logarithmic y-scale. `None` for automatic detection (defaults to `None`) # Returns: - `tuple[str, bool]` Sparkline visualization and whether log scale was used """ if histogram is None or len(histogram) == 0: return "", False # Get the appropriate character set chars: List[str] if format in SPARK_CHARS: chars = SPARK_CHARS[format] else: chars = SPARK_CHARS["ascii"] # automatic detection of log_y if log_y is None: # we bin the histogram values to the number of levels in our sparkline characters hist_hist = np.histogram(histogram, bins=len(chars))[0] # if every bin except the smallest (first) and largest (last) is empty, # then we should use the log scale. if those bins are nonempty, keep the linear scale if hist_hist[1:-1].max() > 0: log_y = False else: log_y = True # Handle log scale if log_y: # Add small value to avoid log(0) hist_data = np.log1p(histogram) else: hist_data = histogram # Normalize to character set range if hist_data.max() > 0: normalized = hist_data / hist_data.max() * (len(chars) - 1) else: normalized = np.zeros_like(hist_data) # Convert to characters spark = "" for val in normalized: idx = round(val) spark += chars[idx] return spark, log_y DEFAULT_SETTINGS: ArraySummarySettings = { "fmt": "unicode", "precision": 2, "stats": True, "shape": True, "dtype": True, "device": True, "requires_grad": True, "sparkline": False, "sparkline_bins": 5, "sparkline_logy": None, "colored": False, "as_list": False, "eq_char": "=", } def apply_color( text: str, color_key: str, colors: Dict[str, str], using_tex: bool ) -> str: if using_tex: return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text else: return ( f"{colors[color_key]}{text}{colors['reset']}" if colors[color_key] else text ) def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str: """Colorize dtype string with specific colors for torch and type names.""" # Handle torch prefix type_part: str = dtype_str prefix_part: Optional[str] = None if "torch." 
in dtype_str: parts = dtype_str.split("torch.") if len(parts) == 2: prefix_part = apply_color("torch", "torch", colors, using_tex) type_part = parts[1] # Handle type coloring color_key: str = "dtype" if "bool" in dtype_str.lower(): color_key = "dtype_bool" elif "int" in dtype_str.lower(): color_key = "dtype_int" elif "float" in dtype_str.lower(): color_key = "dtype_float" type_colored: str = apply_color(type_part, color_key, colors, using_tex) if prefix_part: return f"{prefix_part}.{type_colored}" else: return type_colored def format_shape_colored( shape_val: Any, colors: Dict[str, str], using_tex: bool ) -> str: """Format shape with proper coloring for both 1D and multi-D arrays.""" def apply_color(text: str, color_key: str) -> str: if using_tex: return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text else: return ( f"{colors[color_key]}{text}{colors['reset']}" if colors[color_key] else text ) if len(shape_val) == 1: # For 1D arrays, still color the dimension value return apply_color(str(shape_val[0]), "shape") else: # For multi-D arrays, color each dimension return "(" + ",".join(apply_color(str(dim), "shape") for dim in shape_val) + ")" def format_device_colored( device_str: str, colors: Dict[str, str], using_tex: bool ) -> str: """Format device string with CUDA highlighting.""" def apply_color(text: str, color_key: str) -> str: if using_tex: return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text else: return ( f"{colors[color_key]}{text}{colors['reset']}" if colors[color_key] else text ) if "cuda" in device_str.lower(): return apply_color(device_str, "device_cuda") else: return apply_color(device_str, "device") class _UseDefaultType: pass _USE_DEFAULT = _UseDefaultType() @overload def array_summary( array: Any, fmt: OutputFormat = ..., precision: int = ..., stats: bool = ..., shape: bool = ..., dtype: bool = ..., device: bool = ..., requires_grad: bool = ..., sparkline: bool = ..., sparkline_bins: int = ..., sparkline_logy: Optional[bool] = ..., colored: bool = ..., eq_char: str = ..., *, as_list: Literal[True], ) -> List[str]: ... @overload def array_summary( array: Any, fmt: OutputFormat = ..., precision: int = ..., stats: bool = ..., shape: bool = ..., dtype: bool = ..., device: bool = ..., requires_grad: bool = ..., sparkline: bool = ..., sparkline_bins: int = ..., sparkline_logy: Optional[bool] = ..., colored: bool = ..., eq_char: str = ..., as_list: Literal[False] = ..., ) -> str: ... @overload def array_summary( array: Any, fmt: OutputFormat = ..., precision: int = ..., stats: bool = ..., shape: bool = ..., dtype: bool = ..., device: bool = ..., requires_grad: bool = ..., sparkline: bool = ..., sparkline_bins: int = ..., sparkline_logy: Optional[bool] = ..., colored: bool = ..., eq_char: str = ..., as_list: bool = ..., ) -> Union[str, List[str]]: ... 
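# the overloads above narrow the return type based on `as_list`; the
# implementation below replaces each `_USE_DEFAULT` sentinel with the
# corresponding entry from `DEFAULT_SETTINGS`, so defaults can be adjusted in
# one place. illustrative calls (a hedged sketch, `x` is any hypothetical array/tensor):
#   array_summary(x)                  # -> str
#   array_summary(x, sparkline=True)  # -> str, includes a histogram sparkline
#   array_summary(x, as_list=True)    # -> List[str] of the individual parts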
def array_summary( array: Any, fmt: Union[OutputFormat, _UseDefaultType] = _USE_DEFAULT, precision: Union[int, _UseDefaultType] = _USE_DEFAULT, stats: Union[bool, _UseDefaultType] = _USE_DEFAULT, shape: Union[bool, _UseDefaultType] = _USE_DEFAULT, dtype: Union[bool, _UseDefaultType] = _USE_DEFAULT, device: Union[bool, _UseDefaultType] = _USE_DEFAULT, requires_grad: Union[bool, _UseDefaultType] = _USE_DEFAULT, sparkline: Union[bool, _UseDefaultType] = _USE_DEFAULT, sparkline_bins: Union[int, _UseDefaultType] = _USE_DEFAULT, sparkline_logy: Union[Optional[bool], _UseDefaultType] = _USE_DEFAULT, colored: Union[bool, _UseDefaultType] = _USE_DEFAULT, eq_char: Union[str, _UseDefaultType] = _USE_DEFAULT, as_list: Union[bool, _UseDefaultType] = _USE_DEFAULT, ) -> Union[str, List[str]]: """Format array information into a readable summary. # Parameters: - `array` array-like object (numpy array or torch tensor) - `precision : int` Decimal places (defaults to `2`) - `format : Literal["unicode", "latex", "ascii"]` Output format (defaults to `{default_fmt}`) - `stats : bool` Whether to include statistical info (μ, σ, x̃) (defaults to `True`) - `shape : bool` Whether to include shape info (defaults to `True`) - `dtype : bool` Whether to include dtype info (defaults to `True`) - `device : bool` Whether to include device info for torch tensors (defaults to `True`) - `requires_grad : bool` Whether to include requires_grad info for torch tensors (defaults to `True`) - `sparkline : bool` Whether to include a sparkline visualization (defaults to `False`) - `sparkline_width : int` Width of the sparkline (defaults to `20`) - `sparkline_logy : bool|None` Whether to use logarithmic y-scale for sparkline (defaults to `None`) - `colored : bool` Whether to add color to output (defaults to `False`) - `as_list : bool` Whether to return as list of strings instead of joined string (defaults to `False`) # Returns: - `Union[str, List[str]]` Formatted statistical summary, either as string or list of strings """ if isinstance(fmt, _UseDefaultType): fmt = DEFAULT_SETTINGS["fmt"] if isinstance(precision, _UseDefaultType): precision = DEFAULT_SETTINGS["precision"] if isinstance(stats, _UseDefaultType): stats = DEFAULT_SETTINGS["stats"] if isinstance(shape, _UseDefaultType): shape = DEFAULT_SETTINGS["shape"] if isinstance(dtype, _UseDefaultType): dtype = DEFAULT_SETTINGS["dtype"] if isinstance(device, _UseDefaultType): device = DEFAULT_SETTINGS["device"] if isinstance(requires_grad, _UseDefaultType): requires_grad = DEFAULT_SETTINGS["requires_grad"] if isinstance(sparkline, _UseDefaultType): sparkline = DEFAULT_SETTINGS["sparkline"] if isinstance(sparkline_bins, _UseDefaultType): sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"] if isinstance(sparkline_logy, _UseDefaultType): sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"] if isinstance(colored, _UseDefaultType): colored = DEFAULT_SETTINGS["colored"] if isinstance(as_list, _UseDefaultType): as_list = DEFAULT_SETTINGS["as_list"] if isinstance(eq_char, _UseDefaultType): eq_char = DEFAULT_SETTINGS["eq_char"] array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins) result_parts: List[str] = [] using_tex: bool = fmt == "latex" # Set color scheme based on format and colored flag colors: Dict[str, str] if colored: colors = COLORS["latex"] if using_tex else COLORS["terminal"] else: colors = COLORS["none"] # Get symbols for the current format symbols: Dict[str, str] = SYMBOLS[fmt] # Helper function to colorize text def colorize(text: str, color_key: str) -> str: if 
using_tex: return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text else: return ( f"{colors[color_key]}{text}{colors['reset']}" if colors[color_key] else text ) # Check if dtype is integer type dtype_str: str = array_data.get("dtype", "") is_int_dtype: bool = any( int_type in dtype_str.lower() for int_type in ["int", "uint", "bool"] ) # Format string for numbers float_fmt: str = f".{precision}f" # Handle error status or empty array if ( array_data["status"] in ["empty array", "all NaN", "unknown"] or array_data["size"] == 0 ): status = array_data["status"] result_parts.append(colorize(symbols["warning"] + " " + status, "warning")) else: # Add NaN warning at the beginning if there are NaNs if array_data["has_nans"]: _percent: str = "\\%" if using_tex else "%" nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})" result_parts.append(colorize(nan_str, "warning")) # Statistics if stats: for stat_key in ["mean", "std", "median"]: if array_data[stat_key] is not None: stat_str: str = f"{array_data[stat_key]:{float_fmt}}" stat_colored: str = colorize(stat_str, stat_key) result_parts.append(f"{symbols[stat_key]}={stat_colored}") # Range (min, max) if array_data["range"] is not None: min_val, max_val = array_data["range"] if is_int_dtype: min_str: str = f"{int(min_val):d}" max_str: str = f"{int(max_val):d}" else: min_str = f"{min_val:{float_fmt}}" max_str = f"{max_val:{float_fmt}}" min_colored: str = colorize(min_str, "range") max_colored: str = colorize(max_str, "range") range_str: str = f"{symbols['range']}=[{min_colored},{max_colored}]" result_parts.append(range_str) # Add sparkline if requested if sparkline and array_data["histogram"] is not None: # this should return whether log_y is used or not and then we set the symbol accordingly spark, used_log = generate_sparkline( array_data["histogram"], format=fmt, log_y=sparkline_logy, ) if spark: spark_colored = colorize(spark, "sparkline") dist_symbol = ( symbols["distribution_log"] if used_log else symbols["distribution"] ) result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|") # Add shape if requested if shape and array_data["shape"]: shape_val = array_data["shape"] shape_str = format_shape_colored(shape_val, colors, using_tex) result_parts.append(f"shape{eq_char}{shape_str}") # Add dtype if requested if dtype and array_data["dtype"]: dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex) result_parts.append(f"dtype={dtype_colored}") # Add device if requested and it's a tensor with device info if device and array_data["is_tensor"] and array_data["device"]: device_colored = format_device_colored(array_data["device"], colors, using_tex) result_parts.append(f"device{eq_char}{device_colored}") # Add gradient info if requires_grad and array_data["is_tensor"]: bool_req_grad_symb: str = ( symbols["true"] if array_data["requires_grad"] else symbols["false"] ) result_parts.append( colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad") ) # Return as list if requested, otherwise join with spaces if as_list: return result_parts else: joinchar: str = r" \quad " if using_tex else " " return joinchar.join(result_parts) ``````{ end_of_file="muutils/tensor_info.py" } ``````{ path="muutils/tensor_utils.py" } """utilities for working with tensors and arrays. 
notably: - `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types - `DTYPE_MAP` mapping string representations of types to their type - `TORCH_DTYPE_MAP` mapping string representations of types to torch types - `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match """ from __future__ import annotations import json import typing from typing import Any import jaxtyping import numpy as np import torch from muutils.dictmagic import dotlist_to_nested_dict # pylint: disable=missing-class-docstring TYPE_TO_JAX_DTYPE: dict[Any, Any] = { float: jaxtyping.Float, int: jaxtyping.Int, jaxtyping.Float: jaxtyping.Float, jaxtyping.Int: jaxtyping.Int, # bool bool: jaxtyping.Bool, jaxtyping.Bool: jaxtyping.Bool, np.bool_: jaxtyping.Bool, torch.bool: jaxtyping.Bool, # numpy float np.float16: jaxtyping.Float, np.float32: jaxtyping.Float, np.float64: jaxtyping.Float, np.half: jaxtyping.Float, np.single: jaxtyping.Float, np.double: jaxtyping.Float, # numpy int np.int8: jaxtyping.Int, np.int16: jaxtyping.Int, np.int32: jaxtyping.Int, np.int64: jaxtyping.Int, np.longlong: jaxtyping.Int, np.short: jaxtyping.Int, np.uint8: jaxtyping.Int, # torch float torch.float: jaxtyping.Float, torch.float16: jaxtyping.Float, torch.float32: jaxtyping.Float, torch.float64: jaxtyping.Float, torch.half: jaxtyping.Float, torch.double: jaxtyping.Float, torch.bfloat16: jaxtyping.Float, # torch int torch.int: jaxtyping.Int, torch.int8: jaxtyping.Int, torch.int16: jaxtyping.Int, torch.int32: jaxtyping.Int, torch.int64: jaxtyping.Int, torch.long: jaxtyping.Int, torch.short: jaxtyping.Int, } "dict mapping python, numpy, and torch types to `jaxtyping` types" # np.float_ and np.int_ were deprecated in numpy 1.20 and removed in 2.0 # use try/except for backwards compatibility and type checker friendliness try: TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] except AttributeError: pass # numpy 2.0+ removed these deprecated aliases # TODO: add proper type annotations to this signature # TODO: maybe get rid of this altogether? 
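# (the commented-out `jaxtype_factory` below built shorthand annotation types;
# a hedged sketch of the same idea using `TYPE_TO_JAX_DTYPE` directly:
#   FloatArr = TYPE_TO_JAX_DTYPE[np.float32][np.ndarray, "batch dim"]
# where `FloatArr` and the "batch dim" axis names are hypothetical)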
# def jaxtype_factory( # name: str, # array_type: type, # default_jax_dtype: type[jaxtyping.Float | jaxtyping.Int | jaxtyping.Bool] = jaxtyping.Float, # legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN, # ) -> type: # """usage: # ``` # ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) # x: ATensor["dim1 dim2", np.float32] # ``` # """ # legacy_mode_ = ErrorMode.from_any(legacy_mode) # class _BaseArray: # """jaxtyping shorthand # (backwards compatible with older versions of muutils.tensor_utils) # default_jax_dtype = {default_jax_dtype} # array_type = {array_type} # """ # def __new__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn: # raise TypeError("Type FArray cannot be instantiated.") # def __init_subclass__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn: # raise TypeError(f"Cannot subclass {cls.__name__}") # @classmethod # def param_info(cls, params: typing.Union[str, tuple[Any, ...]]) -> str: # """useful for error printing""" # return "\n".join( # f"{k} = {v}" # for k, v in { # "cls.__name__": cls.__name__, # "cls.__doc__": cls.__doc__, # "params": params, # "type(params)": type(params), # }.items() # ) # @typing._tp_cache # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] # def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type: # type: ignore[misc] # # MyTensor["dim1 dim2"] # if isinstance(params, str): # return default_jax_dtype[array_type, params] # elif isinstance(params, tuple): # if len(params) != 2: # raise Exception( # f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}" # ) # if isinstance(params[0], str): # # MyTensor["dim1 dim2", int] # return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] # elif isinstance(params[0], tuple): # legacy_mode_.process( # f"legacy type annotation was used:\n{cls.param_info(params) = }", # except_cls=Exception, # ) # # MyTensor[("dim1", "dim2"), int] # shape_anot: list[str] = list() # for x in params[0]: # if isinstance(x, str): # shape_anot.append(x) # elif isinstance(x, int): # shape_anot.append(str(x)) # elif isinstance(x, tuple): # shape_anot.append("".join(str(y) for y in x)) # else: # raise Exception( # f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}" # ) # return TYPE_TO_JAX_DTYPE[params[1]][ # array_type, " ".join(shape_anot) # ] # else: # raise Exception( # f"unexpected type for params:\n{cls.param_info(params)}" # ) # _BaseArray.__name__ = name # if _BaseArray.__doc__ is None: # _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }" # _BaseArray.__doc__ = _BaseArray.__doc__.format( # default_jax_dtype=repr(default_jax_dtype), # array_type=repr(array_type), # ) # return _BaseArray if typing.TYPE_CHECKING: # these class definitions are only used here to make pylint happy, # but they make mypy unhappy and there is no way to only run if not mypy # so, later on we have more ignores class ATensor(torch.Tensor): @typing._tp_cache # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type: raise NotImplementedError() class NDArray(torch.Tensor): @typing._tp_cache # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type: raise NotImplementedError() # ATensor = jaxtype_factory("ATensor", torch.Tensor, 
jaxtyping.Float) # type: ignore[misc, assignment] # NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float) # type: ignore[misc, assignment] def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype: """convert numpy dtype to torch dtype""" if isinstance(dtype, torch.dtype): return dtype else: return torch.from_numpy(np.array(0, dtype=dtype)).dtype DTYPE_LIST: list[Any] = [ *[ bool, int, float, ], *[ # ---------- # pytorch # ---------- # floats torch.float, torch.float32, torch.float64, torch.half, torch.double, torch.bfloat16, # complex torch.complex64, torch.complex128, # ints torch.int, torch.int8, torch.int16, torch.int32, torch.int64, torch.long, torch.short, # simplest torch.uint8, torch.bool, ], *[ # ---------- # numpy # ---------- # floats np.float16, np.float32, np.float64, np.half, np.single, np.double, # complex np.complex64, np.complex128, # ints np.int8, np.int16, np.int32, np.int64, np.longlong, np.short, # simplest np.uint8, np.bool_, ], ] "list of all the python, numpy, and torch numerical types I could think of" # np.float_ and np.int_ were deprecated in numpy 1.20 and removed in 2.0 try: DTYPE_LIST.extend([np.float_, np.int_]) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] except AttributeError: pass # numpy 2.0+ removed these deprecated aliases DTYPE_MAP: dict[str, Any] = { **{str(x): x for x in DTYPE_LIST}, **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"}, } "mapping from string representations of types to their type" TORCH_DTYPE_MAP: dict[str, torch.dtype] = { key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items() } "mapping from string representations of types to specifically torch types" # no idea why we have to do this, smh DTYPE_MAP["bool"] = np.bool_ TORCH_DTYPE_MAP["bool"] = torch.bool TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = { "Adagrad": torch.optim.Adagrad, "Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW, "SparseAdam": torch.optim.SparseAdam, "Adamax": torch.optim.Adamax, "ASGD": torch.optim.ASGD, "LBFGS": torch.optim.LBFGS, "NAdam": torch.optim.NAdam, "RAdam": torch.optim.RAdam, "RMSprop": torch.optim.RMSprop, "Rprop": torch.optim.Rprop, "SGD": torch.optim.SGD, } def pad_tensor( tensor: jaxtyping.Shaped[torch.Tensor, "dim1"], # noqa: F821 padded_length: int, pad_value: float = 0.0, rpad: bool = False, ) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]: # noqa: F821 """pad a 1-d tensor on the left with pad_value to length `padded_length` set `rpad = True` to pad on the right instead""" temp: list[torch.Tensor] = [ torch.full( (padded_length - tensor.shape[0],), pad_value, dtype=tensor.dtype, device=tensor.device, ), tensor, ] if rpad: temp.reverse() return torch.cat(temp) def lpad_tensor( tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0 ) -> torch.Tensor: """pad a 1-d tensor on the left with pad_value to length `padded_length`""" return pad_tensor(tensor, padded_length, pad_value, rpad=False) def rpad_tensor( tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0 ) -> torch.Tensor: """pad a 1-d tensor on the right with pad_value to length `pad_length`""" return pad_tensor(tensor, pad_length, pad_value, rpad=True) def pad_array( array: jaxtyping.Shaped[np.ndarray, "dim1"], # noqa: F821 padded_length: int, pad_value: float = 0.0, rpad: bool = False, ) -> jaxtyping.Shaped[np.ndarray, "padded_length"]: # noqa: F821 """pad a 1-d array on the left with pad_value to length `padded_length` set `rpad = 
True` to pad on the right instead""" temp: list[np.ndarray] = [ np.full( (padded_length - array.shape[0],), pad_value, dtype=array.dtype, ), array, ] if rpad: temp.reverse() return np.concatenate(temp) def lpad_array( array: np.ndarray, padded_length: int, pad_value: float = 0.0 ) -> np.ndarray: """pad a 1-d array on the left with pad_value to length `padded_length`""" return pad_array(array, padded_length, pad_value, rpad=False) def rpad_array( array: np.ndarray, pad_length: int, pad_value: float = 0.0 ) -> np.ndarray: """pad a 1-d array on the right with pad_value to length `pad_length`""" return pad_array(array, pad_length, pad_value, rpad=True) def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]: """given a state dict or cache dict, compute the shapes and put them in a nested dict""" return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str: """printable version of get_dict_shapes""" return json.dumps( dotlist_to_nested_dict( { k: str( tuple(v.shape) ) # to string, since indent wont play nice with tuples for k, v in d.items() } ), indent=2, ) class StateDictCompareError(AssertionError): """raised when state dicts don't match""" pass class StateDictKeysError(StateDictCompareError): """raised when state dict keys don't match""" pass class StateDictShapeError(StateDictCompareError): """raised when state dict shapes don't match""" pass class StateDictValueError(StateDictCompareError): """raised when state dict values don't match""" pass def compare_state_dicts( d1: dict[str, Any], d2: dict[str, Any], rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True, ) -> None: """compare two dicts of tensors # Parameters: - `d1 : dict` - `d2 : dict` - `rtol : float` (defaults to `1e-5`) - `atol : float` (defaults to `1e-8`) - `verbose : bool` (defaults to `True`) # Raises: - `StateDictKeysError` : keys don't match - `StateDictShapeError` : shapes don't match (but keys do) - `StateDictValueError` : values don't match (but keys and shapes do) """ # check keys match d1_keys: set[str] = set(d1.keys()) d2_keys: set[str] = set(d2.keys()) symmetric_diff: set[str] = set.symmetric_difference(d1_keys, d2_keys) keys_diff_1: set[str] = d1_keys - d2_keys keys_diff_2: set[str] = d2_keys - d1_keys # sort sets for easier debugging symmetric_diff = set(sorted(symmetric_diff)) keys_diff_1 = set(sorted(keys_diff_1)) keys_diff_2 = set(sorted(keys_diff_2)) diff_shapes_1: str = ( string_dict_shapes({k: d1[k] for k in keys_diff_1}) if verbose else "(verbose = False)" ) diff_shapes_2: str = ( string_dict_shapes({k: d2[k] for k in keys_diff_2}) if verbose else "(verbose = False)" ) if not len(symmetric_diff) == 0: raise StateDictKeysError( f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}" ) # check tensors match shape_failed: list[str] = list() vals_failed: list[str] = list() for k, v1 in d1.items(): v2 = d2[k] # check shapes first if not v1.shape == v2.shape: shape_failed.append(k) else: # if shapes match, check values if not torch.allclose(v1, v2, rtol=rtol, atol=atol): vals_failed.append(k) str_shape_failed: str = ( string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else "" ) str_vals_failed: str = ( string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else "" ) if not len(shape_failed) == 0: raise StateDictShapeError( f"{len(shape_failed)} / {len(d1)} state dict elements don't match in 
shape:\n{shape_failed = }\n{str_shape_failed}" ) if not len(vals_failed) == 0: raise StateDictValueError( f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}" ) ``````{ end_of_file="muutils/tensor_utils.py" } ``````{ path="muutils/timeit_fancy.py" } "`timeit_fancy` is just a fancier version of timeit with more options" from __future__ import annotations import pstats import timeit import cProfile from typing import Callable, Union, TypeVar, NamedTuple, Any import warnings from muutils.statcounter import StatCounter T_return = TypeVar("T_return") class FancyTimeitResult(NamedTuple): """return type of `timeit_fancy`""" timings: StatCounter return_value: T_return # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues] profile: Union[pstats.Stats, None] def timeit_fancy( cmd: Union[Callable[[], T_return], str], setup: Union[str, Callable[[], Any]] = lambda: None, repeats: int = 5, namespace: Union[dict[str, Any], None] = None, get_return: bool = True, do_profiling: bool = False, ) -> FancyTimeitResult: """ Wrapper for `timeit` to get the fastest run of a callable with more customization options. Approximates the functionality of the %timeit magic or command line interface in a Python callable. # Parameters - `cmd: Callable[[], T_return] | str` The callable to time. If a string, it will be passed to `timeit.Timer` as the `stmt` argument. - `setup: str` The setup code to run before `cmd`. If a string, it will be passed to `timeit.Timer` as the `setup` argument. - `repeats: int` The number of times to run `cmd` to get a reliable measurement. - `namespace: dict[str, Any]` Passed to `timeit.Timer` constructor. If `cmd` or `setup` use local or global variables, they must be passed here. See `timeit` documentation for details. - `get_return: bool` Whether to pass the value returned from `cmd`. If True, the return value will be appended in a tuple with execution time. This is for speed and convenience so that `cmd` doesn't need to be run again in the calling scope if the return values are needed. (default: `False`) - `do_profiling: bool` Whether to return a `pstats.Stats` object in addition to the time and return value. (default: `False`) # Returns `FancyTimeitResult`, which is a NamedTuple with the following fields: - `time: float` The time in seconds it took to run `cmd` the minimum number of times to get a reliable measurement. - `return_value: T|None` The return value of `cmd` if `get_return` is `True`, otherwise `None`. - `profile: pstats.Stats|None` A `pstats.Stats` object if `do_profiling` is `True`, otherwise `None`. """ timer: timeit.Timer = timeit.Timer(cmd, setup, globals=namespace) # Perform the timing times: list[float] = timer.repeat(repeats, 1) # Optionally capture the return value profile: pstats.Stats | None = None return_value: T_return | None = None if (get_return or do_profiling) and isinstance(cmd, str): warnings.warn( ( "Can't do profiling or get return value from `cmd` because it is a string." + " If you want to get the return value, pass a callable instead." ), UserWarning, ) if (get_return or do_profiling) and not isinstance(cmd, str): # Optionally perform profiling if do_profiling: profiler: cProfile.Profile = cProfile.Profile() profiler.enable() try: return_value = cmd() except TypeError as e: warnings.warn( f"Failed to get return value from `cmd` due to error (probably passing a string). 
will return `return_value=None`\n{e}", ) if do_profiling: # profiler is def bound here assert isinstance(profiler, cProfile.Profile) # pyright: ignore[reportPossiblyUnboundVariable] profiler.disable() profile = pstats.Stats(profiler).strip_dirs().sort_stats("cumulative") # reset the return value if it wasn't requested if not get_return: return_value = None return FancyTimeitResult( timings=StatCounter(times), # TYPING: Argument is incorrect: Expected `typing.TypeVar`, found `None | @Todo`tyinvalid-argument-type # no idea how to fix return_value=return_value, # type: ignore[invalid-argument-type] profile=profile, ) ``````{ end_of_file="muutils/timeit_fancy.py" } ``````{ path="muutils/validate_type.py" } """experimental utility for validating types in python, see `validate_type`""" from __future__ import annotations from inspect import signature, unwrap import types import typing import functools from typing import Any # this is also for python <3.10 compatibility _GenericAliasTypeNames: typing.List[str] = [ "GenericAlias", "_GenericAlias", "_UnionGenericAlias", "_BaseGenericAlias", ] _GenericAliasTypesList: list[Any] = [ getattr(typing, name, None) for name in _GenericAliasTypeNames ] GenericAliasTypes: tuple[Any, ...] = tuple( [t for t in _GenericAliasTypesList if t is not None] ) class IncorrectTypeException(TypeError): pass class TypeHintNotImplementedError(NotImplementedError): pass class InvalidGenericAliasError(TypeError): pass def _return_validation_except( return_val: bool, value: typing.Any, expected_type: typing.Any ) -> bool: if return_val: return True else: raise IncorrectTypeException( f"Expected {expected_type = } for {value = }", f"{type(value) = }", f"{type(value).__mro__ = }", f"{typing.get_origin(expected_type) = }", f"{typing.get_args(expected_type) = }", "\ndo --tb=long in pytest to see full trace", ) return False def _return_validation_bool(return_val: bool) -> bool: return return_val def validate_type( value: typing.Any, expected_type: typing.Any, do_except: bool = False ) -> bool: """Validate that a `value` is of the `expected_type` # Parameters - `value`: the value to check the type of - `expected_type`: the type to check against. Not all types are supported - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`) (default: `False`) # Returns - `bool`: `True` if the value is of the expected type, `False` otherwise. 
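# Usage (illustrative examples):

```python
validate_type(5, int)                           # True
validate_type([1, 2, 3], typing.List[int])      # True
validate_type({"a": 1}, typing.Dict[str, int])  # True
validate_type((1, "x"), typing.Tuple[int, str]) # True
validate_type(5, typing.Optional[int])          # True
validate_type("a", int, do_except=True)         # raises IncorrectTypeException
```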
# Raises - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True` - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard """ if expected_type is typing.Any: return True # set up the return function depending on `do_except` _return_func: typing.Callable[[bool], bool] = ( # functools.partial doesn't hint the function signature functools.partial( # type: ignore[assignment] _return_validation_except, value=value, expected_type=expected_type ) if do_except else _return_validation_bool ) # handle None type (used in type hints like tuple[int, None]) if expected_type is None: return _return_func(value is None) # base type without args if isinstance(expected_type, type): try: # if you use args on a type like `dict[str, int]`, this will fail return _return_func(isinstance(value, expected_type)) except TypeError as e: if isinstance(e, IncorrectTypeException): raise e origin: typing.Any = typing.get_origin(expected_type) args: tuple[Any, ...] = typing.get_args(expected_type) # useful for debugging # print(f"{value = }, {expected_type = }, {origin = }, {args = }") UnionType = getattr(types, "UnionType", None) if (origin is typing.Union) or ( # this works in python <3.10 False if UnionType is None # return False if UnionType is not available else origin is UnionType # return True if UnionType is available ): return _return_func(any(validate_type(value, arg) for arg in args)) # generic alias, more complicated item_type: type if isinstance(expected_type, GenericAliasTypes): if origin is list: # no args if len(args) == 0: return _return_func(isinstance(value, list)) # incorrect number of args if len(args) != 1: raise InvalidGenericAliasError( f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }", f"{GenericAliasTypes = }", ) # check is list if not isinstance(value, list): return _return_func(False) # check all items in list are of the correct type item_type = args[0] return all(validate_type(item, item_type) for item in value) if origin is dict: # no args if len(args) == 0: return _return_func(isinstance(value, dict)) # incorrect number of args if len(args) != 2: raise InvalidGenericAliasError( f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }", f"{GenericAliasTypes = }", ) # check is dict if not isinstance(value, dict): return _return_func(False) # check all items in dict are of the correct type key_type: type = args[0] value_type: type = args[1] return _return_func( all( validate_type(key, key_type) and validate_type(val, value_type) for key, val in value.items() ) ) if origin is set: # no args if len(args) == 0: return _return_func(isinstance(value, set)) # incorrect number of args if len(args) != 1: raise InvalidGenericAliasError( f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }", f"{GenericAliasTypes = }", ) # check is set if not isinstance(value, set): return _return_func(False) # check all items in set are of the correct type item_type = args[0] return _return_func(all(validate_type(item, item_type) for item in value)) if origin is tuple: # no args if len(args) == 0: return _return_func(isinstance(value, tuple)) # check is tuple if not isinstance(value, tuple): return _return_func(False) # check correct number of items in tuple if 
len(value) != len(args): return _return_func(False) # check all items in tuple are of the correct type return _return_func( all(validate_type(item, arg) for item, arg in zip(value, args)) ) if origin is type: # no args if len(args) == 0: return _return_func(isinstance(value, type)) # incorrect number of args if len(args) != 1: raise InvalidGenericAliasError( f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }", f"{GenericAliasTypes = }", ) # check is type item_type = args[0] if item_type in value.__mro__: return _return_func(True) else: return _return_func(False) # TODO: Callables, etc. raise TypeHintNotImplementedError( f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }", f"{origin = }, {args = }", f"\n{GenericAliasTypes = }", ) else: raise TypeHintNotImplementedError( f"Unsupported type hint {expected_type = } for {value = }", f"{origin = }, {args = }", f"\n{GenericAliasTypes = }", ) def get_fn_allowed_kwargs(fn: typing.Callable[..., Any]) -> typing.Set[str]: """Get the allowed kwargs for a function, raising an exception if the signature cannot be determined.""" try: fn = unwrap(fn) params = signature(fn).parameters except ValueError as e: fn_name: str = getattr(fn, "__name__", str(fn)) err_msg = f"Cannot retrieve signature for {fn_name = } {fn = }: {str(e)}" raise ValueError(err_msg) from e return { param.name for param in params.values() if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY) } ``````{ end_of_file="muutils/validate_type.py" } ``````{ path="CHANGELOG.md" } # Changelog ## [Unreleased] ### Fixed - **`muutils.json_serialize.util.dc_eq`**: Fixed docstring that incorrectly stated `except_when_field_mismatch` defaults to `True` (actual default is `False`), and that it raises `TypeError` (it actually raises `AttributeError`) - **`muutils.json_serialize.util.dc_eq`**: Updated flowchart in docstring to accurately reflect the control flow, including the missing `false_when_class_mismatch` decision branch ### Breaking Changes - **`muutils.logger`**: `Logger.log()` and `SimpleLogger.log()` now require keyword arguments for all parameters after `msg`. This change was made to fix type checker compatibility between the two classes. **Before:** ```python logger.log("message", -10) # lvl as positional arg ``` **After:** ```python logger.log("message", lvl=-10) # lvl as keyword arg ``` ``````{ end_of_file="CHANGELOG.md" } ``````{ path="LICENSE" } GNU GENERAL PUBLIC LICENSE Version 3, 29 June 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The GNU General Public License is a free, copyleft license for software and other kinds of works. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. We, the Free Software Foundation, use the GNU General Public License for most of our software; it applies also to any other work released this way by its authors. You can apply it to your programs, too. When we speak of free software, we are referring to freedom, not price. 
Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. To protect your rights, we need to prevent others from denying you these rights or asking you to surrender the rights. Therefore, you have certain responsibilities if you distribute copies of the software, or if you modify it: responsibilities to respect the freedom of others. For example, if you distribute copies of such a program, whether gratis or for a fee, you must pass on to the recipients the same freedoms that you received. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. Developers that use the GNU GPL protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License giving you legal permission to copy, distribute and/or modify it. For the developers' and authors' protection, the GPL clearly explains that there is no warranty for this free software. For both users' and authors' sake, the GPL requires that modified versions be marked as changed, so that their problems will not be attributed erroneously to authors of previous versions. Some devices are designed to deny users access to install or run modified versions of the software inside them, although the manufacturer can do so. This is fundamentally incompatible with the aim of protecting users' freedom to change the software. The systematic pattern of such abuse occurs in the area of products for individuals to use, which is precisely where it is most unacceptable. Therefore, we have designed this version of the GPL to prohibit the practice for those products. If such problems arise substantially in other domains, we stand ready to extend this provision to those domains in future versions of the GPL, as needed to protect the freedom of users. Finally, every program is threatened constantly by software patents. States should not allow patents to restrict development and use of software on general-purpose computers, but in those that do, we wish to avoid the special danger that patents applied to a free program could make it effectively proprietary. To prevent this, the GPL assures that patents cannot be used to render the program non-free. The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. 
Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. 
You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. 
A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. 
For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. 
Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. 
Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. 
If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. "Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Use with the GNU Affero General Public License. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU Affero General Public License into a single combined work, and to convey the resulting work. 
The terms of this License will continue to apply to the part which is the covered work, but the special requirements of the GNU Affero General Public License, section 13, concerning interaction through a network will apply to the combination as such. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU General Public License, you may choose any version ever published by the Free Software Foundation. If the Program specifies that a proxy can decide which future versions of the GNU General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. 
It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. Copyright (C) This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . Also add information on how to contact you by electronic and paper mail. If the program does terminal interaction, make it output a short notice like this when it starts in an interactive mode: Copyright (C) This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, your program's commands might be different; for a GUI interface, you would use an "about box". You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU GPL, see . The GNU General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. But first, please read . ``````{ end_of_file="LICENSE" } ``````{ path="README.md" } [![PyPI](https://img.shields.io/pypi/v/muutils)](https://pypi.org/project/muutils/) ![PyPI - Downloads](https://img.shields.io/pypi/dm/muutils) [![docs](https://img.shields.io/badge/docs-latest-blue)](https://miv.name/muutils) [![Checks](https://github.com/mivanit/muutils/actions/workflows/checks.yml/badge.svg)](https://github.com/mivanit/muutils/actions/workflows/checks.yml) [![Checks](https://github.com/mivanit/muutils/actions/workflows/make-docs.yml/badge.svg)](https://github.com/mivanit/muutils/actions/workflows/make-docs.yml) [![Coverage](docs/coverage/coverage.svg)](docs/coverage/html/) ![GitHub commits](https://img.shields.io/github/commit-activity/t/mivanit/muutils) ![GitHub commit activity](https://img.shields.io/github/commit-activity/m/mivanit/muutils) ![GitHub closed pull requests](https://img.shields.io/github/issues-pr-closed/mivanit/muutils) ![code size, bytes](https://img.shields.io/github/languages/code-size/mivanit/muutils) `muutils`, stylized as "$\mu$utils" or "μutils", is a collection of miscellaneous python utilities, meant to be small and with no dependencies outside of standard python. 
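A quick taste, using the `kappa` module from the table below (a minimal sketch; the import path is assumed from the file tree):

```python
from muutils.kappa import Kappa  # assumed import path (muutils/kappa.py)

# anonymous __getitem__: wrap a callable so that indexing calls it
square = Kappa(lambda x: x**2)
assert square[2] == 4
```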
# installation PyPi: [muutils](https://pypi.org/project/muutils/) ``` pip install muutils ``` Optional dependencies: ``` pip install muutils[array] # numpy, torch, jaxtyping -- for mlutils, tensor_utils, tensor_info, ml, json_serialize array features pip install muutils[notebook] # ipython -- for nbutils.configure_notebook pip install muutils[parallel] # multiprocess, tqdm -- for parallel processing with progress pip install muutils[web] # weasyprint -- for web/html_to_pdf ``` # documentation [https://miv.name/muutils](https://miv.name/muutils) # modules | Module | Description | |--------|-------------| | [`statcounter`](https://miv.name/muutils/muutils/statcounter.html) | Extension of `collections.Counter` with smart stats computation (mean, variance, percentiles) | | [`dictmagic`](https://miv.name/muutils/muutils/dictmagic.html) | Dictionary utilities: dotlist conversion, `DefaulterDict`, tensor dict condensing | | [`kappa`](https://miv.name/muutils/muutils/kappa.html) | Anonymous getitem (`Kappa(lambda x: x**2)[2]` returns `4`) | | [`sysinfo`](https://miv.name/muutils/muutils/sysinfo.html) | System information collection for logging | | [`misc`](https://miv.name/muutils/muutils/misc.html) | Utilities: `stable_hash`, `list_join`/`list_split`, filename sanitization, `freeze` | | [`interval`](https://miv.name/muutils/muutils/interval.html) | Mathematical intervals (open/closed/half-open) with containment, clamping, set operations | | [`errormode`](https://miv.name/muutils/muutils/errormode.html) | Enum-based error handling (raise/warn/log/ignore) | | [`validate_type`](https://miv.name/muutils/muutils/validate_type.html) | Runtime type validation for basic and generic types | | [`console_unicode`](https://miv.name/muutils/muutils/console_unicode.html) | Safe console output with Unicode/ASCII fallback | | [`spinner`](https://miv.name/muutils/muutils/spinner.html) | Animated spinners with elapsed time and status updates | | [`timeit_fancy`](https://miv.name/muutils/muutils/timeit_fancy.html) | Enhanced timing with multiple runs, profiling, and statistics | | [`dbg`](https://miv.name/muutils/muutils/dbg.html) | Debug printing inspired by Rust's `dbg!` macro | | [`collect_warnings`](https://miv.name/muutils/muutils/collect_warnings.html) | Context manager to capture and summarize warnings | | [`parallel`](https://miv.name/muutils/muutils/parallel.html) | Simplified parallel processing with progress bars | | [`jsonlines`](https://miv.name/muutils/muutils/jsonlines.html) | Simple `jsonl` file reading/writing | | [`group_equiv`](https://miv.name/muutils/muutils/group_equiv.html) | Group elements by equivalence relation (non-transitive) | | [`json_serialize`](https://miv.name/muutils/muutils/json_serialize.html) | Serialize arbitrary Python objects to JSON (works with [ZANJ](https://github.com/mivanit/ZANJ/)) | | [`nbutils`](https://miv.name/muutils/muutils/nbutils.html) | Jupyter utilities: notebook conversion, configuration, mermaid/TeX display | | [`math`](https://miv.name/muutils/muutils/math.html) | Binning functions and matrix power computation | | [`cli`](https://miv.name/muutils/muutils/cli.html) | CLI utilities: boolean argument parsing, flag actions | | [`web`](https://miv.name/muutils/muutils/web.html) | HTML asset inlining for standalone documents | | [`logger`](https://miv.name/muutils/muutils/logger.html) | *(deprecated)* Logging framework, use [`trnbl`](https://github.com/mivanit/trnbl) instead | | [`mlutils`](https://miv.name/muutils/muutils/mlutils.html) | ML pipeline: device 
detection, seeding, checkpoints *(requires `array`)* | | [`tensor_utils`](https://miv.name/muutils/muutils/tensor_utils.html) | PyTorch/numpy type conversions *(requires `array`)* | | [`tensor_info`](https://miv.name/muutils/muutils/tensor_info.html) | Tensor metadata extraction and formatting *(requires `array`)* | | [`ml`](https://miv.name/muutils/muutils/ml.html) | CUDA memory monitoring *(requires `array`)* | # [`ZANJ`](https://github.com/mivanit/ZANJ/) ZANJ is a human-readable and simple format for ML models, datasets, and arbitrary objects. It's built around having a zip file with `json` and `npy` files, and has been spun off into its [own project](https://github.com/mivanit/ZANJ/). There are a couple work-in-progress utilities in [`_wip`](https://github.com/mivanit/muutils/tree/main/muutils/_wip/) that aren't ready for anything, but nothing in this repo is suitable for production. Use at your own risk! ``````{ end_of_file="README.md" } ``````{ path="TODO.md" } # Type Error Fixing TODO ## Instructions 1. Read the entire file `.meta/typing-summary.txt` to get an overview of the current type errors in the codebase. 2. Read the type checker output files: - `.meta/.type-errors/mypy.txt` - `.meta/.type-errors/basedpyright.txt` - `.meta/.type-errors/ty.txt` NOTE: the files are many thousands of lines, you will have to pick a *random* few hundred lines to read. it is important that you pick a random set of lines, since you will be working in parallel with other Claude instances, and we want to avoid everyone working on the same errors. 3. Decide on a good fix to make. For example, you might pick: - the fix with the best **"number of errors / complexity of change" ratio** - a fix that gets us closer to removing an entire category of errors - a fix that gets us closer to having no errors in a specific file (**FOCUS ON THIS!**) 4. Implement that fix 5. run type checking only on the specific file you are changing to verify that the errors are fixed. use `uv run `, not `python -m` # Guidelines: - make sure all type hints are python>=3.8 compatible - always err on the side of STRICTER type hints! - try to avoid breaking changes. check with the user before making breaking changes. if breaking changes are necessary, and the user agrees, make sure to document them properly and add them to CHANGELOG.md ``````{ end_of_file="TODO.md" } ``````{ path="makefile" processed_with="makefile_recipes" } # first/default target is help .PHONY: default default: help ... # download makefile helper scripts from GitHub # uses curl to fetch scripts from the template repository # override version: make self-setup-scripts SCRIPTS_VERSION=v0.5.0 .PHONY: self-setup-scripts self-setup-scripts: @echo "downloading makefile scripts (version: $(SCRIPTS_VERSION))" ... # this recipe is weird. we need it because: # - a one liner for getting the version with toml is unwieldy, and using regex is fragile # - using $$SCRIPT_GET_VERSION within $(shell ...) doesn't work because of escaping issues # - trying to write to the file inside the `gen-version-info` recipe doesn't work, # shell eval happens before our `python ...` gets run and `cat` doesn't see the new file .PHONY: write-proj-version write-proj-version: ... # gets version info from $(PYPROJECT), last version from $(LAST_VERSION_FILE), and python version # uses just `python` for everything except getting the python version. no echo here, because this is "private" .PHONY: gen-version-info gen-version-info: write-proj-version ... 
# getting commit log since the tag specified in $(LAST_VERSION_FILE) # will write to $(COMMIT_LOG_FILE) # when publishing, the contents of $(COMMIT_LOG_FILE) will be used as the tag description (but can be edited during the process) # no echo here, because this is "private" .PHONY: gen-commit-log gen-commit-log: gen-version-info ... # force the version info to be read, printing it out # also force the commit log to be generated, and cat it out .PHONY: version version: gen-commit-log @echo "Current version is $(PROJ_VERSION), last auto-uploaded version is $(LAST_VERSION)" ... .PHONY: setup setup: self-setup-scripts dep-check @echo "download scripts and sync dependencies" ... .PHONY: dep-check-torch dep-check-torch: @echo "see if torch is installed, and which CUDA version and devices it sees" ... # sync dependencies and export to requirements.txt files # - syncs all extras and groups with uv (including dev dependencies) # - compiles bytecode for faster imports # - exports to requirements.txt files per tool.uv-exports.exports config # configure via pyproject.toml:[tool.uv-exports]: # [tool.uv-exports] # exports = [ # { name = "base", extras = [], groups = [] }, # base package deps only # { name = "dev", extras = [], groups = ["dev"] }, # dev dependencies # { name = "all", extras = ["all"], groups = ["dev"] } # everything # ] .PHONY: dep dep: @echo "syncing and exporting dependencies as per $(PYPROJECT) section 'tool.uv-exports.exports'" ... .PHONY: dep-compile dep-compile: @echo "syncing dependencies with bytecode compilation" ... # verify that requirements.txt files match current dependencies # - exports deps to temp directory # - diffs temp against existing requirements files # - FAILS if any differences found (means you need to run `make dep`) # useful in CI to catch when pyproject.toml changed but requirements weren't regenerated .PHONY: dep-check dep-check: @echo "Checking that exported requirements are up to date" ... .PHONY: dep-clean dep-clean: @echo "clean up lock files, .venv, and requirements files" ... # extra tests with python >=3.10 type hints .PHONY: gen-extra-tests gen-extra-tests: ... # format code AND auto-fix linting issues # performs TWO operations: reformats code, then auto-fixes safe linting issues # configure in pyproject.toml:[tool.ruff] .PHONY: format format: @echo "format the source code" ... # runs ruff to check if the code is formatted correctly .PHONY: format-check format-check: @echo "check if the source code is formatted correctly" ... # runs type checks with configured checkers # set TYPE_CHECKERS to customize which checkers run (e.g., TYPE_CHECKERS=mypy,basedpyright) # set TYPING_OUTPUT_DIR to save outputs to files (used by typing-summary) # returns exit code 1 if any checker fails # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # special casing mypy .PHONY: typing typing: gen-extra-tests @echo "running type checks" ... # save type check outputs and generate detailed breakdown # outputs are saved to $(TYPE_ERRORS_DIR)/*.txt # summary is generated to $(TYPING_SUMMARY_FILE) .PHONY: typing-summary typing-summary: gen-extra-tests @echo "running type checks and saving to $(TYPE_ERRORS_DIR)/" ... # run tests with pytest # you can pass custom args. for example: # make test PYTEST_OPTIONS="--maxfail=1 -x" # pytest config in pyproject.toml:[tool.pytest.ini_options] .PHONY: test test: gen-extra-tests @echo "running tests" ... .PHONY: check check: format-check test typing @echo "run format checks, tests, and typing checks" ... 
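# example local workflow (illustrative only; it just chains the targets documented above):
#   make setup     # download helper scripts and sync dependencies
#   make format    # reformat code and auto-fix safe lint issues
#   make check     # format-check + test + typing in one go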
# generates a whole tree of documentation in html format. # see `$(MAKE_DOCS_SCRIPT_PATH)` and the templates in `$(DOCS_RESOURCES_DIR)/templates/html/` for more info .PHONY: docs-html docs-html: @echo "generate html docs" ... # instead of a whole website, generates a single markdown file with all docs using the templates in `$(DOCS_RESOURCES_DIR)/templates/markdown/`. # this is useful if you want to have a copy that you can grep/search, but those docs are much messier. .PHONY: docs-md docs-md: @echo "generate combined (single-file) docs in markdown" ... # generate coverage reports from test results # WARNING: if .coverage file not found, will automatically run `make test` first # - generates text report: $(COVERAGE_REPORTS_DIR)/coverage.txt # - generates SVG badge: $(COVERAGE_REPORTS_DIR)/coverage.svg # - generates HTML report: $(COVERAGE_REPORTS_DIR)/html/ # - removes .gitignore from html dir (we publish coverage with docs) .PHONY: cov cov: @echo "generate coverage reports" ... # runs the coverage report, then the docs, then the combined docs .PHONY: docs docs: cov docs-html docs-md todo lmcat @echo "generate all documentation and coverage reports" ... # remove generated documentation files, but preserve resources # - removes all docs except those in DOCS_RESOURCES_DIR # - preserves files/patterns specified in pyproject.toml config # - distinct from `make clean` (which removes temp build files, not docs) # configure via pyproject.toml:[tool.makefile.docs]: # [tool.makefile.docs] # output_dir = "docs" # must match DOCS_DIR in makefile # no_clean = [ # files/patterns to preserve when cleaning # "resources/**", # "*.svg", # "*.css" # ] .PHONY: docs-clean docs-clean: @echo "remove generated docs except resources" ... # get all TODO's from the code # configure via pyproject.toml:[tool.makefile.inline-todo]: # [tool.makefile.inline-todo] # search_dir = "." # directory to search for TODOs # out_file_base = "docs/other/todo-inline" # output file path (without extension) # context_lines = 2 # lines of context around each TODO # extensions = ["py", "md"] # file extensions to search # tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "DOC"] # tags to look for # exclude = ["docs/**", ".venv/**", "scripts/get_todos.py"] # patterns to exclude # branch = "main" # git branch for URLs # # repo_url = "..." # repository URL (defaults to [project.urls.{repository,github}]) # # template_md = "..." # custom jinja2 template for markdown output # # template_issue = "..." # custom format string for issues # # template_html_source = "..." # custom html template path # tag_label_map = { "BUG" = "bug", "TODO" = "enhancement", "DOC" = "documentation" } # mapping of tags to GitHub issue labels .PHONY: todo todo: @echo "get all TODO's from the code" ... .PHONY: lmcat-tree lmcat-tree: @echo "show in console the lmcat tree view" ... .PHONY: lmcat lmcat: @echo "write the lmcat full output to pyproject.toml:[tool.lmcat.output]" ... # verify git is ready for publishing # REQUIRES: # - current branch must be $(PUBLISH_BRANCH) # - no uncommitted changes (git status --porcelain must be empty) # EXITS with error if either condition fails .PHONY: verify-git verify-git: @echo "checking git status" ... # build package distribution files # creates wheel (.whl) and source distribution (.tar.gz) in dist/ .PHONY: build build: @echo "build the package" ... 
# publish package to PyPI and create git tag # PREREQUISITES: # - must be on $(PUBLISH_BRANCH) branch with clean git status (verified by verify-git) # - must have $(PYPI_TOKEN_FILE) with your PyPI token # - version in pyproject.toml must be different from $(LAST_VERSION_FILE) # PROCESS: # 1. runs checks, validates version, builds package, verifies git clean # 2. prompts for version confirmation (you can edit $(COMMIT_LOG_FILE) at this point) # 3. creates git commit updating $(LAST_VERSION_FILE) # 4. creates annotated git tag with commit log as description # 5. pushes tag to origin # 6. uploads to PyPI via twine .PHONY: publish publish: check version build verify-git @echo "Ready to publish $(PROJ_VERSION) to PyPI" ... # cleans up temporary files: # - caches: .mypy_cache, .ruff_cache, .pytest_cache, .coverage # - build artifacts: dist/, build/, *.egg-info # - test temp files: $(TESTS_TEMP_DIR) # - __pycache__ directories and *.pyc/*.pyo files in $(PACKAGE_NAME), $(TESTS_DIR), $(DOCS_DIR) # uses `-` prefix on find commands to continue even if directories don't exist # distinct from `make docs-clean`, which removes generated documentation # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # added cleanup of generated test file .PHONY: clean clean: @echo "clean up temporary files" ... # remove all generated/build files including .venv # runs: clean + docs-clean + dep-clean # removes .venv, uv.lock, requirements.txt files, generated docs, build artifacts # run `make dep` after this to reinstall dependencies .PHONY: clean-all clean-all: clean docs-clean dep-clean @echo "clean up all temporary files, dep files, venv, and generated docs" ... .PHONY: info info: gen-version-info @echo "# makefile variables" ... .PHONY: info-long info-long: info @echo "# other variables" ... # Smart help command: shows general help, or detailed info about specific targets # Usage: # make help - shows general help (list of targets + makefile variables) # make help="test" - shows detailed info about the 'test' recipe # make HELP="test clean" - shows detailed info about multiple recipes # make h=* - shows detailed info about all recipes (wildcard expansion) # make H="test" - same as HELP (case variations supported) # # All variations work: help/HELP/h/H with values like "foo", "foo bar", "*", "--all" .PHONY: help help: ... ``````{ end_of_file="makefile" } ``````{ path="pyproject.toml" } # metadata # ================================================== [project] name = "muutils" version = "0.9.0" description = "miscellaneous python utilities" readme = "README.md" requires-python = ">=3.8" license = { text = "GPL-3.0-only" } authors = [ { name = "mivanit", email = "mivanits@umich.edu" } ] classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Development Status :: 4 - Beta", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", "Topic :: Utilities", "Typing :: Typed", ] dependencies = [] # no required deps! 
[project.urls] Homepage = "https://miv.name/muutils" Repository = "https://github.com/mivanit/muutils" Documentation = "https://miv.name/muutils/" Issues = "https://github.com/mivanit/muutils/issues" # dependencies # ================================================== [project.optional-dependencies] array = [ "numpy>=1.24.4; python_version < '3.9'", "numpy>1.24.4; python_version >= '3.9'", "torch>=1.13.1,<2.5.0; python_version < '3.9'", "torch>=1.13.1; python_version >= '3.9' and python_version < '3.13'", "torch>=2.5.0; python_version >= '3.13'", "jaxtyping>=0.2.12", ] # special group for CI, where we install cpu torch separately array_no_torch = [ "numpy>=1.24.4; python_version < '3.9'", "numpy>1.24.4; python_version >= '3.9'", "jaxtyping>=0.2.12", ] notebook = [ "ipython>=8.0.0", ] parallel = [ "multiprocess>=0.70.17", "tqdm>=4.67.1", ] web = [ "weasyprint>=60.0", ] [dependency-groups] dev = [ # typing "mypy>=1.0.1; python_version < '3.9'", "mypy>=1.15; python_version >= '3.9'", "typing-extensions; python_version < '3.11'", "beartype>=0.14.1", "ty>=0.0.17", "basedpyright", "pandas-stubs>=2.0.0", "types-tqdm", # tests & coverage "pytest>=8.2.2", "pytest-cov>=4.1.0", "coverage-badge>=1.1.0", "setuptools>=78.1.1; python_version >= '3.9'", # https://github.com/mivanit/muutils/security/dependabot/31 # for testing plotting and notebooks "ipykernel", "jupyter", # for jupyter "h11>=0.16.0", # https://github.com/mivanit/muutils/security/dependabot/23 "tornado>=6.5; python_version >= '3.9'", # https://github.com/mivanit/muutils/security/dependabot/33 # plotting "pandas", "matplotlib>=3.0.0", "plotly>=5.0.0", "beautifulsoup4", # generating docs "pdoc>=14.6.0", # https://github.com/mivanit/muutils/security/dependabot/7 "jinja2>=3.1.6", # lmcat -- a custom library. 
not exactly docs, but lets an LLM see all the code "lmcat>=0.2.0; python_version >= '3.11'", # tomli since no tomlib in python < 3.11 "tomli>=2.1.0; python_version < '3.11'", # twine dep "twine", ] lint = [ # lint "pycln>=2.1.3", "ruff>=0.4.8", ] # build system and tooling configuration # ================================================== [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.pytest.ini_options] filterwarnings = [ "ignore::muutils.nbutils.configure_notebook.UnknownFigureFormatWarning", # don't show warning for unknown figure format "ignore::muutils.nbutils.configure_notebook.PlotlyNotInstalledWarning", # don't show warning for missing plotly "ignore::muutils.json_serialize.serializable_dataclass.ZanjMissingWarning", # don't show warning for missing zanj (can't have as a dep since zanj depends on muutils) "ignore: PEP 484 type hint*:beartype.roar._roarwarn.BeartypeDecorHintPep585DeprecationWarning", ] addopts = "--jaxtyping-packages=beartype.beartype" [tool.coverage.run] omit = ["*/_remote_module_non_scriptable*"] [tool.ruff] # Exclude the directories specified in the global excludes exclude = ["tests/input_data", "tests/junk_data", "_wip/", ".meta/scripts/"] [tool.ruff.lint.per-file-ignores] "muutils/tensor_info.py" = [ "E701", # multiple statements on one line (colon) ] "tests/unit/math/test_matrix_powers_torch.py" = [ "F722", # jaxtyping stuff ] "muutils/math/matrix_powers.py" = [ "F722", # jaxtyping stuff ] [tool.basedpyright] # TODO: change this to strict eventually typeCheckingMode = "standard" # file include/exclude include = ["muutils", "tests"] exclude = [ "tests/input_data", "tests/junk_data", "tests/_temp", "_wip", ".meta/scripts/", ".venv", ] # rules reportConstantRedefinition = false # I always use all caps for globals, not just consts reportDeprecated = false # this library is backwards compatible back to 3.8, so we are using lots of deprecated stuff reportUnsupportedDunderAll = false # we use __all__ a lot for docs stuff reportExplicitAny = false # we allow Any in many places. if it's there, it's intentional [tool.ty.src] exclude = [ "tests/input_data/", "tests/junk_data/", "tests/_temp/", "tests/benchmark_parallel.py", "_wip/", ".meta/scripts/", ] # TODO: remove this once we clean up the `# type: ignore` comments that are needed # for mypy/pyright but not for ty. See https://docs.astral.sh/ty/configuration/ [tool.ty.rules] # mostly for version compatibility unused-ignore-comment = "ignore" unused-type-ignore-comment = "ignore" ignore-comment-unknown-rule = "ignore" # [[tool.ty.overrides]] # include = ["tests/unit/json_serialize/test_json_serialize.py", "tests/unit/test_jsonlines.py"] # rules = {invalid-argument-type = "ignore", invalid-key = "ignore"} [tool.mypy] files = ["muutils", "tests"] exclude = [ # tests "tests/input_data", "tests/junk_data", "tests/_temp/", "tests/benchmark_parallel.py", # wip stuff "_wip/", # not our problem ".meta/scripts/", ] show_error_codes = true # we disable this in the makefile for old versions check_untyped_defs = true [tool.lmcat] output = "docs/other/lmcat.txt" # changing this might mean it wont be accessible from the docs ignore_patterns = [ "docs/**", ".venv/**", ".git/**", ".meta/**", "uv.lock", ".ruff_cache/**", ".github/ISSUE_TEMPLATE/**", "_wip/**", "sweep.yaml", # there are... a lot of tests. 
we usually dont need to put these in lmcat "tests/**", ] [tool.lmcat.glob_process] "[mM]akefile" = "makefile_recipes" # [tool.makefile] # ================================================================= [tool.makefile.docs] warnings_ignore = [ "Error parsing type annotation .* for muutils\\..*\\. Import of np failed:", "Error parsing type annotation .* for muutils\\..*\\. Import of JsonSerializer failed:", "Error parsing type annotation .* for muutils\\..*\\. Import of StatCounter failed:", "Error parsing type annotation .* for muutils\\..*\\. Import of Union failed:" ] # Custom export configurations [tool.makefile.uv-exports] args = [ "--no-hashes" ] exports = [ # no groups, no extras, just the base dependencies { name = "base", groups = false, extras = false }, # all extras but no groups { name = "extras", groups = false, extras = true }, # include the dev group (this is the default behavior) { name = "dev", groups = true }, # only the lint group -- custom options for this { name = "lint", options = ["--only-group", "lint"] }, # all groups and extras { name = "all", filename="requirements.txt", groups = true, extras=true }, # all groups and extras, a different way { name = "all", groups = true, options = ["--all-extras"] }, ] [tool.makefile.inline-todo] search_dir = "." out_file_base = "docs/other/todo-inline.md" # changing this might mean it wont be accessible from the docs context_lines = 5 extensions = ["py", "md"] tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "NOTE"] exclude = [ "docs/**", ".venv/**", "scripts/get_todos.py", "_wip/**", ] [tool.inline-todo.tag_label_map] NOTE = "documentation" CRIT = "bug" TODO = "enhancement" FIXME = "bug" BUG = "bug" HACK = "enhancement" ``````{ end_of_file="pyproject.toml" }