docs for muutils v0.8.12
muutils, stylized as "μutils", is a collection
of miscellaneous python utilities, meant to be small and with no
dependencies outside of standard python.
PyPI: muutils
pip install muutils
Note that for using mlutils, tensor_utils,
nbutils.configure_notebook, or the array serialization
features of json_serialize, you will need to install with
optional array dependencies:
pip install muutils[array]
hosted html docs: https://miv.name/muutils
statcounter: an extension of collections.Counter that provides "smart" computation of stats (mean, variance, median, other percentiles) from the counter object without using Counter.elements()
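For instance, a quick sketch (the stats above are documented; the exact method names mean()/median() are assumptions based on typical usage):

from muutils.statcounter import StatCounter

c = StatCounter([1, 1, 2, 3, 3, 3])
print(c.mean(), c.median())  # stats computed from the counts, without expanding elements()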
dictmagic: has utilities for working with dictionaries, like:

- dotlist_to_nested_dict and nested_dict_to_dotlist for converting between flat and nested dicts:

  >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
  {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
  >>> nested_dict_to_dotlist({'a': {'b': {'c': 1, 'd': 2}, 'e': 3}})
  {'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}

- DefaulterDict, which works like a defaultdict but can generate the default value based on the key
- condense_tensor_dict, which takes a dict of dotlist-tensors and gives a more human-readable summary:

  >>> model = MyGPT()
  >>> print(condense_tensor_dict(model.named_parameters(), 'yaml'))
  embed:
    W_E: (50257, 768)
  pos_embed:
    W_pos: (1024, 768)
  blocks:
    '[0-11]':
      attn:
        '[W_Q, W_K, W_V]': (12, 768, 64)
        W_O: (12, 64, 768)
        '[b_Q, b_K, b_V]': (12, 64)
        b_O: (768,)
  <...>

kappa: anonymous __getitem__, so you can do things like
>>> k = Kappa(lambda x: x**2)
>>> k[2]
4

sysinfo: utility for getting a bunch of system information. useful for logging.
misc: contains a few utilities:

- stable_hash() uses hashlib.sha256 to compute a hash of an object that is stable across runs of python
- list_join and list_split, which behave like str.join and str.split but for lists
- sanitize_fname and dict_to_filename, for simplifying the creation of unique filenames
- shorten_numerical_to_str() and str_to_numeric, which turn numbers like 123456789 into "123M" and back
- freeze, which prevents an object from being modified. Also see gelidum
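A hedged sketch of a few of these (argument orders are assumptions based on the descriptions above):

from muutils.misc import stable_hash, list_split, shorten_numerical_to_str

print(stable_hash({"a": 1}))                # same value across python runs, unlike builtin hash()
print(list_split([1, 2, -1, 3, 4], -1))     # assumed to give [[1, 2], [3, 4]], like str.split for lists
print(shorten_numerical_to_str(123456789))  # "123M" per the description above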
nbutils: contains utilities for working with jupyter notebooks, such as nbutils.configure_notebook
json_serialize: a tool for serializing and loading arbitrary python objects into json. plays nicely with ZANJ
tensor_utils: contains minor utilities for working with pytorch tensors and numpy arrays, mostly for making type conversions easier
group_equiv: groups elements from a sequence according to a given equivalence relation, without assuming that the equivalence relation obeys the transitive property
jsonlines: an extremely simple utility for reading/writing jsonl files
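A sketch of what this looks like, assuming the module exposes jsonl_load / jsonl_write helpers (one json object per line):

from muutils.jsonlines import jsonl_write, jsonl_load

jsonl_write("runs.jsonl", [{"step": 1, "loss": 0.5}, {"step": 2, "loss": 0.4}])
print(jsonl_load("runs.jsonl"))  # [{'step': 1, 'loss': 0.5}, {'step': 2, 'loss': 0.4}]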
ZANJ: a human-readable and simple format for ML models, datasets, and arbitrary objects. It's built around a zip file containing json and npy files, and has been spun off into its own project.
There are a couple work-in-progress utilities in _wip
that aren’t ready for anything, but nothing in this repo is suitable for
production. Use at your own risk!
muutils.collect_warnings

class CollateWarnings(contextlib.AbstractContextManager['CollateWarnings']):

Capture every warning issued inside a with block and print a collated summary when the block exits.
Internally this wraps
warnings.catch_warnings(record=True) so that all warnings
raised in the block are recorded. When the context exits, identical
warnings are grouped and (optionally) printed with a user-defined
format.
Parameters:

- print_on_exit : bool
  Whether to print the summary when the context exits (defaults to True)
- fmt : str
  Format string used for printing each line of the summary. Available fields are:
  - {count} : number of occurrences
  - {filename} : file where the warning originated
  - {lineno} : line number
  - {category} : warning class name
  - {message} : warning message text
  (defaults to "({count}x) {filename}:{lineno} {category}: {message}")
Returns:

- CollateWarnings
  The context-manager instance. After exit, the attribute counts holds a mapping {(filename, lineno, category, message): count}

>>> import warnings
>>> with CollateWarnings() as cw:
... warnings.warn("deprecated", DeprecationWarning)
... warnings.warn("deprecated", DeprecationWarning)
... warnings.warn("other", UserWarning)
(2x) /tmp/example.py:42 DeprecationWarning: deprecated
(1x) /tmp/example.py:43 UserWarning: other
>>> cw.counts
{('/tmp/example.py', 42, 'DeprecationWarning', 'deprecated'): 2,
('/tmp/example.py', 43, 'UserWarning', 'other'): 1}

CollateWarnings(
    print_on_exit: bool = True,
    fmt: str = '({count}x) {filename}:{lineno} {category}: {message}'
)

counts: collections.Counter[tuple[str, int, str, str]]
print_on_exit: bool
fmt: str
muutils.console_unicode

def get_console_safe_str(default: str, fallback: str) -> str

Determine a console-safe string based on the preferred encoding.
This function attempts to encode a given default string
using the system’s preferred encoding. If encoding is successful, it
returns the default string; otherwise, it returns a
fallback string.
Parameters:

- default : str
  The primary string intended for use, to be tested against the system's preferred encoding.
- fallback : str
  The alternative string to be used if default cannot be encoded in the system's preferred encoding.

Returns:

- str
  Either default or fallback, based on whether default can be encoded safely.

>>> get_console_safe_str("café", "cafe")
"café"  # This result may vary based on the system's preferred encoding.
muutils.dbg

this code is based on an implementation of the Rust builtin dbg! for Python, originally from https://github.com/tylerwince/pydbg/blob/master/pydbg.py although it has been significantly modified

licensed under MIT:

Copyright (c) 2019 Tyler Wince

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
PATH_MODE: Literal['relative', 'absolute'] = 'relative'
DEFAULT_VAL_JOINER: str = ' = '
def dbg(
exp: Union[~_ExpType, muutils.dbg._NoExpPassedSentinel] = <muutils.dbg._NoExpPassedSentinel object>,
formatter: Optional[Callable[[Any], str]] = None,
val_joiner: str = ' = '
) -> Union[~_ExpType, muutils.dbg._NoExpPassedSentinel]

Call dbg with any variable or expression.
Calling dbg will print to stderr the current filename and lineno, as well as the passed expression and what the expression evaluates to:
from muutils.dbg import dbg
a = 2
b = 5
dbg(a+b)
def square(x: int) -> int:
return x * x
dbg(square(a))
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: Dict[str, Union[NoneType, bool, int, str]] = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': True, 'sparkline_bins': 7, 'sparkline_logy': None, 'colored': True, 'eq_char': '='}
DBG_TENSOR_VAL_JOINER: str = ': '
def tensor_info(tensor: Any) -> str

DBG_DICT_DEFAULTS: Dict[str, Union[bool, int, str]] = {'key_types': True, 'val_types': True, 'max_len': 32, 'indent': ' ', 'max_depth': 3}

DBG_LIST_DEFAULTS: Dict[str, Union[bool, int, str]] = {'max_len': 16, 'summary_show_types': True}

def list_info(lst: List[Any]) -> str

TENSOR_STR_TYPES: Set[str] = {"<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"}

def dict_info(d: Dict[Any, Any], depth: int = 0) -> str

def info_auto(obj: Any) -> str

Automatically format an object for debugging.

def dbg_tensor(tensor: ~_ExpType) -> ~_ExpType

dbg function for tensors, using tensor_info formatter.

def dbg_dict(d: ~_ExpType_dict) -> ~_ExpType_dict

dbg function for dictionaries, using dict_info formatter.

def dbg_auto(obj: ~_ExpType) -> ~_ExpType

dbg function for automatic formatting based on type.
def grep_repr(
obj: Any,
pattern: str | re.Pattern[str],
*,
char_context: int | None = 20,
line_context: int | None = None,
before_context: int = 0,
after_context: int = 0,
context: int | None = None,
max_count: int | None = None,
cased: bool = False,
loose: bool = False,
line_numbers: bool = False,
highlight: bool = True,
color: str = '31',
separator: str = '--',
quiet: bool = False
) -> Optional[List[str]]

grep-like search on repr(obj) with improved grep-style options.
By default, string patterns are case-insensitive. Pre-compiled regex patterns use their own flags.
Parameters:

- obj: Object to search (its repr() string is scanned)
- pattern: Regular expression pattern (string or pre-compiled)
- char_context: Characters of context before/after each match (default: 20)
- line_context: Lines of context before/after; overrides char_context
- before_context: Lines of context before match (like grep -B)
- after_context: Lines of context after match (like grep -A)
- context: Lines of context before AND after (like grep -C)
- max_count: Stop after this many matches
- cased: Force case-sensitive search for string patterns
- loose: Normalize spaces/punctuation for flexible matching
- line_numbers: Show line numbers in output
- highlight: Wrap matches with ANSI color codes
- color: ANSI color code (default: "31" for red)
- separator: Separator between multiple matches
- quiet: Return results instead of printing

Returns:

- None if quiet=False (prints to stdout)
- List[str] if quiet=True (returns formatted output lines)
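A usage sketch based only on the parameters documented above:

from muutils.dbg import grep_repr

d = {"alpha": 1, "beta": 2}
grep_repr(d, r"b\w+")                     # prints each match in repr(d) with ~20 chars of context
lines = grep_repr(d, "BETA", quiet=True)  # string patterns are case-insensitive; returns the lines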
muutils.dictmagic

making working with dictionaries easier

- DefaulterDict: like a defaultdict, but default_factory is passed the key as an argument
- condense_nested_dicts: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
- condense_tensor_dict: convert a dictionary of tensors to a dictionary of shapes
- kwargs_to_nested_dict: given kwargs from fire, convert them to a nested dict

class DefaulterDict(typing.Dict[~_KT, ~_VT], typing.Generic[~_KT, ~_VT]):

like a defaultdict, but default_factory is passed the key as an argument
default_factory: Callable[[~_KT], ~_VT]

def defaultdict_to_dict_recursive(
    dd: Union[collections.defaultdict, muutils.dictmagic.DefaulterDict]
) -> dict

Convert a defaultdict or DefaulterDict to a normal dict, recursively
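A minimal sketch of both (assuming default_factory is passed as a keyword argument at construction):

from collections import defaultdict
from muutils.dictmagic import DefaulterDict, defaultdict_to_dict_recursive

d = DefaulterDict(default_factory=lambda key: len(key))  # the factory receives the missing key
print(d["hello"])  # 5

nested = defaultdict(lambda: defaultdict(int))
nested["a"]["b"] += 1
print(defaultdict_to_dict_recursive(nested))  # {'a': {'b': 1}}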
def dotlist_to_nested_dict(dot_dict: Dict[str, Any], sep: str = '.') -> Dict[str, Any]

Convert a dict with dot-separated keys to a nested dict
Example:
>>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
def nested_dict_to_dotlist(
nested_dict: Dict[str, Any],
sep: str = '.',
allow_lists: bool = False
) -> dict[str, typing.Any]

Convert a nested dict to a dict with dot-separated keys (the inverse of dotlist_to_nested_dict)

def update_with_nested_dict(
original: dict[str, typing.Any],
update: dict[str, typing.Any]
) -> dict[str, typing.Any]Update a dict with a nested dict
Example:

>>> update_with_nested_dict({'a': {'b': 1}, 'c': -1}, {'a': {'b': 2}})
{'a': {'b': 2}, 'c': -1}

Parameters:

- original: dict[str, Any]
  the dict to update (will be modified in-place)
- update: dict[str, Any]
  the dict to update with

Returns:

- dict
  the updated dict

def kwargs_to_nested_dict(
kwargs_dict: dict[str, typing.Any],
sep: str = '.',
strip_prefix: Optional[str] = None,
when_unknown_prefix: Union[muutils.errormode.ErrorMode, str] = ErrorMode.Warn,
transform_key: Optional[Callable[[str], str]] = None
) -> dict[str, typing.Any]

given kwargs from fire, convert them to a nested dict
if strip_prefix is not None, then all keys must start with the
prefix. by default, will warn if an unknown prefix is found, but can be
set to raise an error or ignore it:
when_unknown_prefix: ErrorMode
Example:
def main(**kwargs):
print(kwargs_to_nested_dict(kwargs))
fire.Fire(main)

running the above script will give:
$ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}

Parameters:

- kwargs_dict: dict[str, Any]
  the kwargs dict to convert
- sep: str = "."
  the separator to use for nested keys
- strip_prefix: Optional[str] = None
  if not None, then all keys must start with this prefix
- when_unknown_prefix: ErrorMode = ErrorMode.WARN
  what to do when an unknown prefix is found
- transform_key: Callable[[str], str] | None = None
  a function to apply to each key before adding it to the dict (applied after stripping the prefix)

def is_numeric_consecutive(lst: list[str]) -> bool

Check if the list of keys is numeric and consecutive.
def condense_nested_dicts_numeric_keys(data: dict[str, typing.Any]) -> dict[str, typing.Any]

condense a nested dict, by condensing numeric keys with matching values to ranges
>>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
{'[1-3]': 1, '[4-6]': 2}
>>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
{"1": {"[1-2]": "a"}, "2": "b"}def condense_nested_dicts_matching_values(
data: dict[str, typing.Any],
val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None
) -> dict[str, typing.Any]

condense a nested dict, by condensing keys with matching values

Parameters:

- data : dict[str, Any]
  data to process
- val_condense_fallback_mapping : Callable[[Any], Hashable] | None
  a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to None)

def condense_nested_dicts(
data: dict[str, typing.Any],
condense_numeric_keys: bool = True,
condense_matching_values: bool = True,
val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None
) -> dict[str, typing.Any]

condense a nested dict, by condensing numeric or matching keys with matching values to ranges
combines the functionality of
condense_nested_dicts_numeric_keys() and
condense_nested_dicts_matching_values()
it’s not reversible because types are lost to make the printing pretty
Parameters:

- data : dict[str, Any]
  data to process
- condense_numeric_keys : bool
  whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") (defaults to True)
- condense_matching_values : bool
  whether to condense keys with matching values (defaults to True)
- val_condense_fallback_mapping : Callable[[Any], Hashable] | None
  a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to None)

def tuple_dims_replace(
t: tuple[int, ...],
dims_names_map: Optional[dict[int, str]] = None
) -> tuple[typing.Union[int, str], ...]

TensorDict = typing.Dict[str, ForwardRef('torch.Tensor|np.ndarray')]
TensorIterable = typing.Iterable[typing.Tuple[str, ForwardRef('torch.Tensor|np.ndarray')]]
TensorDictFormats = typing.Literal['dict', 'json', 'yaml', 'yml']
def condense_tensor_dict(
data: 'TensorDict | TensorIterable',
fmt: Literal['dict', 'json', 'yaml', 'yml'] = 'dict',
*args,
shapes_convert: Callable[[tuple], Any] = <function _default_shapes_convert>,
drop_batch_dims: int = 0,
sep: str = '.',
dims_names_map: Optional[dict[int, str]] = None,
condense_numeric_keys: bool = True,
condense_matching_values: bool = True,
val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
return_format: Optional[Literal['dict', 'json', 'yaml', 'yml']] = None
) -> Union[str, dict[str, str | tuple[int, ...]]]

Convert a dictionary of tensors to a dictionary of shapes.
by default, values are converted to strings of their shapes (for nice
printing). If you want the actual shapes, set
shapes_convert = lambda x: x or
shapes_convert = None.
data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]
a either a TensorDict dict from strings to tensors, or an
TensorIterable iterable of (key, tensor) pairs (like you
might get from a dict().items()) )fmt : TensorDictFormats format to return the result in
– either a dict, or dump to json/yaml directly for pretty printing. will
crash if yaml is not installed. (defaults to 'dict')shapes_convert : Callable[[tuple], Any] conversion of a
shape tuple to a string or other format (defaults to turning it into a
string and removing quotes) (defaults to
lambdax:str(x).replace('"', '').replace("'", ''))drop_batch_dims : int number of leading dimensions to
drop from the shape (defaults to 0)sep : str separator to use for nested keys (defaults to
'.')dims_names_map : dict[int, str] | None convert certain
dimension values in shape. not perfect, can be buggy (defaults to
None)condense_numeric_keys : bool whether to condense
numeric keys (e.g. “1”, “2”, “3”) to ranges (e.g. “[1-3]”), passed on to
condense_nested_dicts (defaults to True)condense_matching_values : bool whether to condense
keys with matching values, passed on to
condense_nested_dicts (defaults to True)val_condense_fallback_mapping : Callable[[Any], Hashable] | None
a function to apply to each value before adding it to the dict (if it’s
not hashable), passed on to condense_nested_dicts (defaults
to None)return_format : TensorDictFormats | None legacy alias
for fmt kwargstr|dict[str, str|tuple[int, ...]] dict if
return_format='dict', a string for json or
yaml output>>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
>>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml'))embed:
W_E: (50257, 768)
pos_embed:
W_pos: (1024, 768)
blocks:
'[0-11]':
attn:
'[W_Q, W_K, W_V]': (12, 768, 64)
W_O: (12, 64, 768)
'[b_Q, b_K, b_V]': (12, 64)
b_O: (768,)
mlp:
W_in: (768, 3072)
b_in: (3072,)
W_out: (3072, 768)
b_out: (768,)
unembed:
W_U: (768, 50257)
  b_U: (50257,)

Raises:

- ValueError : if return_format is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
muutils.errormode

provides ErrorMode enum for handling errors consistently

pass an error_mode: ErrorMode to a function to specify how to handle a certain kind of exception. That function then, instead of raising or calling warnings.warn, calls error_mode.process with the message and the exception.
you can also specify the exception class to raise, the warning class to use, and the source of the exception/warning.
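For example, a hedged sketch of the pattern this enables (process and its kwargs are documented below under ErrorMode.process):

from muutils.errormode import ErrorMode

def read_value(d: dict, key: str, error_mode: ErrorMode = ErrorMode.WARN):
    if key not in d:
        # warns under WARN, raises KeyError under EXCEPT, is silent under IGNORE
        error_mode.process(f"missing key {key!r}", except_cls=KeyError)
        return None
    return d[key]

read_value({"a": 1}, "b")  # emits a UserWarning and returns None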
class WarningFunc(typing.Protocol):Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto[T](Protocol):
def meth(self) -> T:
...
WarningFunc(*args, **kwargs)

LoggingFunc = typing.Callable[[str], NoneType]

def GLOBAL_WARN_FUNC(unknown)

Issue a warning, or maybe ignore it or raise an exception.
- message: Text of the warning message.
- category: The Warning category subclass. Defaults to UserWarning.
- stacklevel: How far up the call stack to make this warning appear. A value of 2 for example attributes the warning to the caller of the code calling warn().
- source: If supplied, the destroyed object which emitted a ResourceWarning
- skip_file_prefixes: An optional tuple of module filename prefixes indicating frames to skip during stacklevel computations for stack frame attribution.
def GLOBAL_LOG_FUNC(*args, sep=' ', end='\n', file=None, flush=False)

Prints the values to a stream, or to sys.stdout by default.

- sep: string inserted between values, default a space.
- end: string appended after the last value, default a newline.
- file: a file-like object (stream); defaults to the current sys.stdout.
- flush: whether to forcibly flush the stream.
def custom_showwarning(
message: Warning | str,
category: Optional[Type[Warning]] = None,
filename: str | None = None,
lineno: int | None = None,
file: Optional[TextIO] = None,
line: Optional[str] = None
) -> None

class ErrorMode(enum.Enum):

Enum for handling errors consistently

pass one of the instances of this enum to a function to specify how to handle a certain kind of exception.

That function then, instead of raising or calling warnings.warn, calls error_mode.process with the message and the exception.
EXCEPT = ErrorMode.Except
WARN = ErrorMode.Warn
LOG = ErrorMode.Log
IGNORE = ErrorMode.Ignore
def process(
self,
msg: str,
except_cls: Type[Exception] = <class 'ValueError'>,
warn_cls: Type[Warning] = <class 'UserWarning'>,
except_from: Optional[Exception] = None,
warn_func: muutils.errormode.WarningFunc | None = None,
log_func: Optional[Callable[[str], NoneType]] = None
)

process an exception or warning according to the error mode
Parameters:

- msg : str
  message to pass to except_cls or warn_func
- except_cls : typing.Type[Exception]
  exception class to raise, must be a subclass of Exception (defaults to ValueError)
- warn_cls : typing.Type[Warning]
  warning class to use, must be a subclass of Warning (defaults to UserWarning)
- except_from : typing.Optional[Exception]
  will raise except_cls(msg) from except_from if not None (defaults to None)
- warn_func : WarningFunc | None
  function to use for warnings, must have the signature warn_func(msg: str, category: typing.Type[Warning], source: typing.Any = None) -> None (defaults to None)
- log_func : LoggingFunc | None
  function to use for logging, must have the signature log_func(msg: str) -> None (defaults to None)

Raises:

- except_cls : raised (with msg) when the mode is EXCEPT
- ValueError : if the error mode is not recognized

def from_any(
cls,
mode: str | muutils.errormode.ErrorMode,
allow_aliases: bool = True,
allow_prefix: bool = True
) -> muutils.errormode.ErrorMode

initialize an ErrorMode from a string or an
ErrorMode instance
def serialize(self) -> str

def load(cls, data: str) -> muutils.errormode.ErrorMode

ERROR_MODE_ALIASES: dict[str, muutils.errormode.ErrorMode] = {
    'except': ErrorMode.Except, 'warn': ErrorMode.Warn, 'log': ErrorMode.Log, 'ignore': ErrorMode.Ignore,
    'e': ErrorMode.Except, 'error': ErrorMode.Except, 'err': ErrorMode.Except, 'raise': ErrorMode.Except,
    'w': ErrorMode.Warn, 'warning': ErrorMode.Warn,
    'l': ErrorMode.Log, 'print': ErrorMode.Log, 'output': ErrorMode.Log, 'show': ErrorMode.Log, 'display': ErrorMode.Log,
    'i': ErrorMode.Ignore, 'silent': ErrorMode.Ignore, 'quiet': ErrorMode.Ignore, 'nothing': ErrorMode.Ignore,
}
map of string aliases to ErrorMode instances
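For example, based on the alias table above:

from muutils.errormode import ErrorMode

assert ErrorMode.from_any("w") is ErrorMode.Warn
assert ErrorMode.from_any("raise") is ErrorMode.Except
assert ErrorMode.from_any(ErrorMode.Log) is ErrorMode.Log  # instances pass through unchanged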
muutils.group_equiv

group items by assuming that eq_func defines an equivalence relation
def group_by_equivalence(
items_in: Sequence[~T],
eq_func: Callable[[~T, ~T], bool]
) -> list[list[~T]]

group items by assuming that eq_func implies an
equivalence relation but might not be transitive
so, if f(a,b) and f(b,c) then f(a,c) might be false, but we still want to put [a,b,c] in the same class
note that lists are used to avoid the need for hashable items, and to allow for duplicates
Parameters:

- items_in: Sequence[T]
  the items to group
- eq_func: Callable[[T, T], bool]
  a function that returns true if two items are equivalent. need not be transitive
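For example, with a non-transitive "close enough" relation (a sketch; group and item ordering may differ):

from muutils.group_equiv import group_by_equivalence

# 0.4~0.8 and 0.8~1.2 even though 0.4 and 1.2 are not within 0.5 of each other
groups = group_by_equivalence([0.4, 0.8, 1.2, 9.0], lambda a, b: abs(a - b) < 0.5)
print(groups)  # expected: [[0.4, 0.8, 1.2], [9.0]]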
muutils.interval

represents a mathematical Interval over the real numbers
Number = typing.Union[float, int]

class Interval:

Represents a mathematical interval, open by default.
The Interval class can represent both open and closed intervals, as well as half-open intervals. It supports various initialization methods and provides containment checks.
Examples:
>>> i1 = Interval(1, 5) # Default open interval (1, 5)
>>> 3 in i1
True
>>> 1 in i1
False
>>> i2 = Interval([1, 5]) # Closed interval [1, 5]
>>> 1 in i2
True
>>> i3 = Interval(1, 5, closed_L=True) # Half-open interval [1, 5)
>>> str(i3)
'[1, 5)'
>>> i4 = ClosedInterval(1, 5) # Closed interval [1, 5]
>>> i5 = OpenInterval(1, 5) # Open interval (1, 5)
Interval(
*args: Union[Sequence[Union[float, int]], float, int],
is_closed: Optional[bool] = None,
closed_L: Optional[bool] = None,
closed_R: Optional[bool] = None
)

lower: Union[float, int]
upper: Union[float, int]
closed_L: bool
closed_R: bool
singleton_set: Optional[set[Union[float, int]]]
is_closed: bool
is_open: bool
is_half_open: bool
is_singleton: bool
is_empty: bool
is_finite: bool
singleton: Union[float, int]

def get_empty() -> muutils.interval.Interval

def get_singleton(value: Union[float, int]) -> muutils.interval.Interval

def numerical_contained(self, item: Union[float, int]) -> bool

def interval_contained(self, item: muutils.interval.Interval) -> bool

def from_str(cls, input_str: str) -> muutils.interval.Interval

def copy(self) -> muutils.interval.Interval

def size(self) -> float

Returns the size of the interval.
Returns:

- float
  the size of the interval

def clamp(self, value: Union[int, float], epsilon: float = 1e-10) -> float

Clamp the given value to the interval bounds.

For open bounds, the clamped value will be slightly inside the interval (by epsilon).

Parameters:

- value : Union[int, float]
  the value to clamp.
- epsilon : float
  margin for open bounds (defaults to _EPSILON)

Returns:

- float
  the clamped value

Raises:

- ValueError : If the input value is NaN.

def intersection(self, other: muutils.interval.Interval) -> muutils.interval.Interval

def union(self, other: muutils.interval.Interval) -> muutils.interval.Interval
class ClosedInterval(Interval):

A closed interval [a, b]: both endpoints are included. Otherwise behaves like Interval; see the Interval docstring and examples above.

ClosedInterval(*args: Union[Sequence[float], float], **kwargs: Any)

Inherited members: lower, upper, closed_L, closed_R, singleton_set, is_closed, is_open, is_half_open, is_singleton, is_empty, is_finite, singleton, get_empty, get_singleton, numerical_contained, interval_contained, from_str, copy, size, clamp, intersection, union

class OpenInterval(Interval):
An open interval (a, b): both endpoints are excluded. Otherwise behaves like Interval; see the Interval docstring and examples above.

OpenInterval(*args: Union[Sequence[float], float], **kwargs: Any)

Inherited members: lower, upper, closed_L, closed_R, singleton_set, is_closed, is_open, is_half_open, is_singleton, is_empty, is_finite, singleton, get_empty, get_singleton, numerical_contained, interval_contained, from_str, copy, size, clamp, intersection, union
muutils.json_serialize

submodule for serializing things to json in a recoverable way

you can throw any object into muutils.json_serialize.json_serialize and it will return a JSONitem, meaning a bool, int, float, str, None, list of JSONitems, or a dict mapping to JSONitem.
The goal of this is if you want to just be able to store something as
relatively human-readable JSON, and don’t care as much about recovering
it, you can throw it into json_serialize and it will just
work. If you want to do so in a recoverable way, check out ZANJ.
it will do so by looking in DEFAULT_HANDLERS, which will keep it as-is if it's already valid, then try to find a .serialize() method on the object, and then have a bunch of special cases. You can add handlers by initializing a JsonSerializer object and passing a sequence of them to handlers_pre
additionally, SerializableDataclass is a special kind of dataclass where you specify how to serialize each field, and a .serialize() method is automatically added to the class. This is done by using the serializable_dataclass decorator, inheriting from SerializableDataclass, and serializable_field in place of dataclasses.field when defining non-standard fields.
This module plays nicely with and is a dependency of the ZANJ library,
which extends this to support saving things to disk in a more efficient
way than just plain json (arrays are saved as npy files, for example),
and automatically detecting how to load saved objects into their
original classes.
def json_serialize(
obj: Any,
path: tuple[typing.Union[str, int], ...] = ()
) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]

serialize object to json-serializable object with default config
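A minimal usage sketch (the handler list below determines the exact output; Path objects, for instance, are documented as serializing to posix strings):

from pathlib import Path
from muutils.json_serialize import json_serialize

print(json_serialize({"a": [1, 2.5, None], "where": Path("data/run1")}))
# expected: {'a': [1, 2.5, None], 'where': 'data/run1'}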
def serializable_dataclass(
_cls=None,
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
properties_to_serialize: Optional[list[str]] = None,
register_handler: bool = True,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except,
on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn,
methods_no_override: list[str] | None = None,
**kwargs
)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!
types will be validated (like pydantic) unless
on_typecheck_mismatch is set to
ErrorMode.IGNORE
behavior of most kwargs matches that of
dataclasses.dataclass, but with some additional kwargs. any
kwargs not listed here are passed to
dataclasses.dataclass
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 __annotations__ to determine
fields.
If init is true, an __init__() method is added to the
class. If repr is true, a __repr__() method is added. If
order is true, rich comparison dunder methods are added. If unsafe_hash
is true, a __hash__() method function is added. If frozen
is true, fields may not be assigned to after instance creation.
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
a: int
    b: str

>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

- _cls : _type_
  class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
- init : bool
  whether to add an __init__ method (passed to dataclasses.dataclass) (defaults to True)
- repr : bool
  whether to add a __repr__ method (passed to dataclasses.dataclass) (defaults to True)
- order : bool
  whether to add rich comparison methods (passed to dataclasses.dataclass) (defaults to False)
- unsafe_hash : bool
  whether to add a __hash__ method (passed to dataclasses.dataclass) (defaults to False)
- frozen : bool
  whether to make the class frozen (passed to dataclasses.dataclass) (defaults to False)
- properties_to_serialize : Optional[list[str]]
  which properties to add to the serialized data dict. SerializableDataclass only (defaults to None)
- register_handler : bool
  if true, register the class with ZANJ for loading. SerializableDataclass only (defaults to True)
- on_typecheck_error : ErrorMode
  what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false. SerializableDataclass only
- on_typecheck_mismatch : ErrorMode
  what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True. SerializableDataclass only
- methods_no_override : list[str]|None
  list of methods that should not be overridden by the decorator. by default, __eq__, serialize, load, and validate_fields_types are overridden by this function, but you can disable this if you'd rather write your own. dataclasses.dataclass might still overwrite these, and those options take precedence. SerializableDataclass only (defaults to None)
- **kwargs
  (passed to dataclasses.dataclass)

Returns:

- _type_
  the decorated class

Raises:

- KWOnlyError : only raised if kw_only is True and python version is <3.9, since dataclasses.dataclass does not support this
- NotSerializableFieldException : if a field is not a SerializableField
- FieldSerializationError : if there is an error serializing a field
- AttributeError : if a property is not found on the class
- FieldLoadingError : if there is an error loading a field

def serializable_field(
*_args,
default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
doc: str | None = None,
metadata: Optional[mappingproxy] = None,
kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
**kwargs: Any
) -> Any

Create a new SerializableField
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
doc: str | None = None, # new in python 3.14. can alternately pass `description` to match pydantic, but this is discouraged
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
### ----------------------------------------------------------------------
### new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
Parameters:

- serialize: whether to serialize this field when serializing the class
- serialization_fn: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the SerializerHandlers defined in muutils.json_serialize.json_serialize
- loading_fn: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- deserialize_fn: new alternative to loading_fn. takes only the field's value, not the whole class. if both loading_fn and deserialize_fn are provided, an error will be raised.
- assert_type: whether to assert the type of the field when loading. if False, will not check the type of the field.
- custom_typecheck_fn: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

note that loading_fn takes the dict of the class, not the field. if you wanted a loading_fn that does nothing, you'd write:

class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )

using deserialize_fn instead:

class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )

In the above code, my_field is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide serialization_fn and loading_fn to serialize and load the container. ZANJ will automatically do this for you.

- custom_value_check_fn: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test

def arr_metadata(arr) -> dict[str, list[int] | str | int]

get metadata for a numpy array
def load_array(
arr: Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]],
array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None
) -> Any

load a json-serialized array, infer the mode if not specified
BASE_HANDLERS = (
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'),
)
JSONitem = typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]
class JsonSerializer:

Json serialization class (holds configs)

Parameters:

- array_mode : ArrayMode
  how to write arrays (defaults to "array_list_meta")
- error_mode : ErrorMode
  what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn") (defaults to "except")
- handlers_pre : MonoTuple[SerializerHandler]
  handlers to use before the default handlers (defaults to tuple())
- handlers_default : MonoTuple[SerializerHandler]
  default handlers to use (defaults to DEFAULT_HANDLERS)
- write_only_format : bool
  changes _FORMAT_KEY keys in output to "write_format" (when you want to serialize something in a way that zanj won't try to recover the object when loading) (defaults to False)

Raises:

- ValueError: on init, if args is not empty
- SerializationException: on json_serialize(), if any error occurs when trying to serialize an object and error_mode is set to ErrorMode.EXCEPT

JsonSerializer(
*args,
array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta',
error_mode: muutils.errormode.ErrorMode = ErrorMode.Except,
handlers_pre: None = (),
    handlers_default: None = DEFAULT_HANDLERS,  # the full DEFAULT_HANDLERS tuple is listed under muutils.json_serialize.json_serialize below
write_only_format: bool = False
)

array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
error_mode: muutils.errormode.ErrorMode
write_only_format: bool
handlers: None
def json_serialize(
self,
obj: Any,
path: tuple[typing.Union[str, int], ...] = ()
) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]

def hashify(
self,
obj: Any,
path: tuple[typing.Union[str, int], ...] = (),
force: bool = True
) -> Union[bool, int, float, str, tuple]

try to turn any object into something hashable
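A configuration sketch using the documented kwargs (assumes numpy is installed):

import numpy as np
from muutils.json_serialize import JsonSerializer

jser = JsonSerializer(array_mode="list")         # arrays as plain lists, no metadata
print(jser.json_serialize({"x": np.arange(3)}))  # expected: {'x': [0, 1, 2]}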
def try_catch(func: Callable)

wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
def dc_eq(
dc1,
dc2,
except_when_class_mismatch: bool = False,
false_when_class_mismatch: bool = True,
except_when_field_mismatch: bool = False
) -> bool

checks if two dataclasses which (might) hold numpy arrays are equal
Parameters:

- dc1: the first dataclass
- dc2: the second dataclass
- except_when_class_mismatch: bool
  if True, will throw TypeError if the classes are different. if not, will return false by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
- false_when_class_mismatch: bool
  only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
- except_when_field_mismatch: bool
  only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. if True, will throw TypeError if the fields are different. (default: True)

Returns:

- bool: True if the dataclasses are equal, False otherwise

Raises:

- TypeError: if the dataclasses are of different classes
- AttributeError: if the dataclasses have different fields

          [START]
▼
┌───────────┐ ┌─────────┐
│dc1 is dc2?├─►│ classes │
└──┬────────┘No│ match? │
──── │ ├─────────┤
(True)◄──┘Yes │No │Yes
──── ▼ ▼
┌────────────────┐ ┌────────────┐
│ except when │ │ fields keys│
│ class mismatch?│ │ match? │
├───────────┬────┘ ├───────┬────┘
│Yes │No │No │Yes
▼ ▼ ▼ ▼
─────────── ┌──────────┐ ┌────────┐
{ raise } │ except │ │ field │
{ TypeError } │ when │ │ values │
─────────── │ field │ │ match? │
│ mismatch?│ ├────┬───┘
├───────┬──┘ │ │Yes
│Yes │No │No ▼
▼ ▼ │ ────
─────────────── ───── │ (True)
{ raise } (False)◄┘ ────
{ AttributeError} ─────
───────────────
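A sketch of why this exists: plain == on dataclasses holding arrays raises, since an array comparison is ambiguous in a boolean context (assumes numpy is installed):

import dataclasses
import numpy as np
from muutils.json_serialize import dc_eq

@dataclasses.dataclass
class Point:
    name: str
    coords: np.ndarray

a = Point("p", np.array([1.0, 2.0]))
b = Point("p", np.array([1.0, 2.0]))
print(dc_eq(a, b))  # True -- arrays compared elementwise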
class SerializableDataclass(abc.ABC):

Base class for serializable dataclasses
only for linting and type checking, still need to call
serializable_dataclass decorator
@serializable_dataclass
class MyClass(SerializableDataclass):
a: int
    b: str

and then you can call my_obj.serialize() to get a dict
that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
x: str
y: MyClass
act_fun: torch.nn.Module = serializable_field(
default=torch.nn.ReLU(),
serialization_fn=lambda x: str(x),
deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]

returns the class as a dict, implemented by using the @serializable_dataclass decorator
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T

takes in an appropriately structured dict and returns an instance of the class, implemented by using the @serializable_dataclass decorator
def validate_fields_types(
self,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
def validate_field_type(
self,
field: muutils.json_serialize.serializable_field.SerializableField | str,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

given a dataclass, check the field matches the type hint
def diff(
self,
other: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
of_serialized: bool = False
) -> dict[str, typing.Any]

get a rich and recursive diff between two instances of a serializable dataclass
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

- other : SerializableDataclass
  other instance to compare against
- of_serialized : bool
  if true, compare serialized data and not raw values (defaults to False)

Returns:

- dict[str, Any]

Raises:

- ValueError : if the instances are not of the same type
- ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass

def update_from_nested_dict(self, nested_dict: dict[str, typing.Any])

update the instance from a nested dict, useful for configuration from command line args
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
muutils.json_serialize.array

this utilities module handles serialization and loading of numpy and torch arrays as json

- array_list_meta is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
- array_b64_meta is the most efficient, but is not human readable.
- external is mostly for use in ZANJ

ArrayMode = typing.Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']

def array_n_elements(arr) -> int

get the number of elements in an array
def arr_metadata(arr) -> dict[str, list[int] | str | int]

get metadata for a numpy array
def serialize_array(
jser: "'JsonSerializer'",
arr: numpy.ndarray,
path: Union[str, Sequence[str | int]],
array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None
) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]

serialize a numpy or pytorch array in one of several modes
if the object is zero-dimensional, simply get the unique item
array_mode: ArrayMode can be one of:

- list: serialize as a list of values, no metadata (equivalent to arr.tolist())
- array_list_meta: serialize dict with metadata, actual list under the key data
- array_hex_meta: serialize dict with metadata, actual hex string under the key data
- array_b64_meta: serialize dict with metadata, actual base64 string under the key data
for array_list_meta, array_hex_meta, and
array_b64_meta, the serialized object is:
{
_FORMAT_KEY: <array_list_meta|array_hex_meta>,
"shape": arr.shape,
"dtype": str(arr.dtype),
"data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
}
Parameters:

- arr : Any
  array to serialize
- array_mode : ArrayMode
  mode in which to serialize the array (defaults to None and inheriting from jser: JsonSerializer)

Returns:

- JSONitem
  json serialized array

Raises:

- KeyError : if the array mode is not valid

def infer_array_mode(
arr: Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]
) -> Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']

given a serialized array, infer the mode
assumes the array was serialized via
serialize_array()
def load_array(
arr: Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]],
array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None
) -> Any

load a json-serialized array, infer the mode if not specified
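A hedged round-trip sketch, assuming the default array_mode (array_list_meta) and that numpy is installed:

import numpy as np
from muutils.json_serialize import json_serialize, load_array

arr = np.arange(6).reshape(2, 3)
ser = json_serialize(arr)    # dict with shape/dtype metadata under array_list_meta
restored = load_array(ser)   # mode inferred from the serialized form
assert np.array_equal(arr, np.array(restored))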
muutils.json_serialize.json_serialize

provides the basic framework for json serialization of objects

notably:

- SerializerHandler defines how to serialize a specific type of object
- JsonSerializer handles configuration for which handlers to use
- json_serialize provides the default configuration if you don't care -- call it on any object!

SERIALIZER_SPECIAL_KEYS: None = ('__name__', '__doc__', '__module__', '__class__', '__dict__', '__annotations__')
SERIALIZER_SPECIAL_FUNCS: dict[str, typing.Callable] = {'str': <class 'str'>, 'dir': <built-in function dir>, 'type': <function <lambda>>, 'repr': <function <lambda>>, 'code': <function <lambda>>, 'sourcefile': <function <lambda>>}
SERIALIZE_DIRECT_AS_STR: Set[str] = {"<class 'torch.dtype'>", "<class 'torch.device'>"}
ObjectPath = tuple[typing.Union[str, int], ...]
class SerializerHandler:

a handler for a specific type of object
- `check : Callable[[JsonSerializer, Any, ObjectPath], bool]` takes a JsonSerializer, an object, and the current path, returns whether to use this handler
- `serialize : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]` takes a JsonSerializer, an object, and the current path, returns the serialized object
- `desc : str` description of the handler (optional)
SerializerHandler(
check: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], bool],
serialize_func: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]],
uid: str,
desc: str
)

check: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], bool]
serialize_func: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]
uid: str
desc: str
def serialize(self) -> dict

serialize the handler info
BASE_HANDLERS: None = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'))
DEFAULT_HANDLERS: None = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects inSERIALIZE_DIRECT_AS_STRto strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings'))
class JsonSerializer:

Json serialization class (holds configs)

Parameters:
- array_mode : ArrayMode how to write arrays (defaults to "array_list_meta")
- error_mode : ErrorMode what to do when we can’t serialize an object (will use repr as fallback if “ignore” or “warn”) (defaults to "except")
- handlers_pre : MonoTuple[SerializerHandler] handlers to use before the default handlers (defaults to tuple())
- handlers_default : MonoTuple[SerializerHandler] default handlers to use (defaults to DEFAULT_HANDLERS)
- write_only_format : bool changes _FORMAT_KEY keys in output to "write_format" (when you want to serialize something in a way that zanj won’t try to recover the object when loading) (defaults to False)

Raises:
- ValueError : on init, if args is not empty
- SerializationException : on json_serialize(), if any error occurs when trying to serialize an object and error_mode is set to ErrorMode.EXCEPT

JsonSerializer(
*args,
array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta',
error_mode: muutils.errormode.ErrorMode = ErrorMode.Except,
handlers_pre: None = (),
handlers_default: None = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings')),
write_only_format: bool = False
)

array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
error_mode: muutils.errormode.ErrorMode
write_only_format: bool
handlers: None
def json_serialize(
self,
obj: Any,
path: tuple[typing.Union[str, int], ...] = ()
) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]

def hashify(
self,
obj: Any,
path: tuple[typing.Union[str, int], ...] = (),
force: bool = True
) -> Union[bool, int, float, str, tuple]

try to turn any object into something hashable

GLOBAL_JSON_SERIALIZER: muutils.json_serialize.json_serialize.JsonSerializer = <muutils.json_serialize.json_serialize.JsonSerializer object>

def json_serialize(
obj: Any,
path: tuple[typing.Union[str, int], ...] = ()
) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]

serialize object to json-serializable object with default config
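A small usage sketch of this default entry point (the printed output is illustrative, since the exact serialized form of a dataclass depends on the handlers listed above):

from dataclasses import dataclass
from pathlib import Path
from muutils.json_serialize import json_serialize

@dataclass
class RunConfig:
    name: str
    seed: int
    out_dir: Path

# dataclasses become dicts, Path objects become posix strings,
# and sets become lists, per DEFAULT_HANDLERS
print(json_serialize(RunConfig(name="exp1", seed=42, out_dir=Path("runs/exp1"))))
print(json_serialize({"tags": {"fast", "baseline"}}))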
muutils.json_serialize.serializable_dataclass
save and load objects to and from json or compatible formats in a recoverable way
d = dataclasses.asdict(my_obj) will give you a dict, but
if some fields are not json-serializable, you will get an error when you
call json.dumps(d). This module provides a way around
that.
Instead, you define your class:
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn’t too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
class CantGetTypeHintsWarning(builtins.UserWarning):

special warning for when we can’t get type hints

class ZanjMissingWarning(builtins.UserWarning):

special warning for when ZANJ is missing – register_loader_serializable_dataclass will not work
def zanj_register_loader_serializable_dataclass(cls: Type[~T])

Register a serializable dataclass with the ZANJ import

this allows ZANJ().read() to load the class and not just return plain dicts
class FieldIsNotInitOrSerializeWarning(builtins.UserWarning):

warning raised when a field is neither init nor serialize
def SerializableDataclass__validate_field_type(
self: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
field: muutils.json_serialize.serializable_field.SerializableField | str,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

given a dataclass, check the field matches the type hint

this function is provided as SerializableDataclass.validate_field_type

Parameters:
- self : SerializableDataclass SerializableDataclass instance
- field : SerializableField | str field to validate, will get from self.__dataclass_fields__ if a str
- on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, the function will return False (defaults to _DEFAULT_ON_TYPECHECK_ERROR)

Returns:
- bool if the field type is correct. False if the field type is incorrect or an exception is thrown and on_typecheck_error is ignore

def SerializableDataclass__validate_fields_types__dict(
self: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> dict[str, bool]

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

returns a dict of field names to bools, where the bool is if the field type is valid
def SerializableDataclass__validate_fields_types(
self: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

class SerializableDataclass(abc.ABC):

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator
(usage examples are the same as in the module docstring above)
def serialize(self) -> dict[str, typing.Any]

returns the class as a dict, implemented by using @serializable_dataclass decorator

def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
def validate_fields_types(
self,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
def validate_field_type(
self,
field: muutils.json_serialize.serializable_field.SerializableField | str,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

given a dataclass, check the field matches the type hint
def diff(
self,
other: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
of_serialized: bool = False
) -> dict[str, typing.Any]

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:
- other : SerializableDataclass other instance to compare against
- of_serialized : bool if true, compare serialized data and not raw values (defaults to False)

Returns:
- dict[str, Any]

Raises:
- ValueError : if the instances are not of the same type
- ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass

def update_from_nested_dict(self, nested_dict: dict[str, typing.Any])

update the instance from a nested dict, useful for configuration from command line args

- `nested_dict : dict[str, Any]`
  nested dict to update the instance with
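A short sketch of these two methods, reusing MyClass from the module docstring (assuming the decorator and base class are importable from muutils.json_serialize):

from muutils.json_serialize import SerializableDataclass, serializable_dataclass

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

old = MyClass(a=1, b="q")
new = MyClass(a=1, b="z")
print(old.diff(new))  # {'b': {'self': 'q', 'other': 'z'}}

# update_from_nested_dict mutates the instance in place
old.update_from_nested_dict({"b": "z"})
assert old == new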
def get_cls_type_hints_cached(cls: Type[~T]) -> dict[str, typing.Any]

cached typing.get_type_hints for a class

def get_cls_type_hints(cls: Type[~T]) -> dict[str, typing.Any]

helper function to get type hints for a class

class KWOnlyError(builtins.NotImplementedError):

kw-only dataclasses are not supported in python <3.9

class FieldError(builtins.ValueError):

base class for field errors

class NotSerializableFieldException(FieldError):

field is not a SerializableField

class FieldSerializationError(FieldError):

error while serializing a field

class FieldLoadingError(FieldError):

error while loading a field

class FieldTypeMismatchError(FieldError, builtins.TypeError):

error when a field type does not match the type hint
def serializable_dataclass(
_cls=None,
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
properties_to_serialize: Optional[list[str]] = None,
register_handler: bool = True,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except,
on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn,
methods_no_override: list[str] | None = None,
**kwargs
)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs. any kwargs not listed here are passed to dataclasses.dataclass

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str

>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:
- _cls : _type_ class to decorate. don’t pass this arg, just use this as a decorator (defaults to None)
- init : bool whether to add an __init__ method (passed to dataclasses.dataclass) (defaults to True)
- repr : bool whether to add a __repr__ method (passed to dataclasses.dataclass) (defaults to True)
- order : bool whether to add rich comparison methods (passed to dataclasses.dataclass) (defaults to False)
- unsafe_hash : bool whether to add a __hash__ method (passed to dataclasses.dataclass) (defaults to False)
- frozen : bool whether to make the class frozen (passed to dataclasses.dataclass) (defaults to False)
- properties_to_serialize : Optional[list[str]] which properties to add to the serialized data dict (SerializableDataclass only) (defaults to None)
- register_handler : bool if true, register the class with ZANJ for loading (SerializableDataclass only) (defaults to True)
- on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false (SerializableDataclass only)
- on_typecheck_mismatch : ErrorMode what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True (SerializableDataclass only)
- methods_no_override : list[str]|None list of methods that should not be overridden by the decorator. by default, __eq__, serialize, load, and validate_fields_types are overridden by this function, but you can disable this if you’d rather write your own. dataclasses.dataclass might still overwrite these, and those options take precedence (SerializableDataclass only) (defaults to None)
- **kwargs (passed to dataclasses.dataclass)

Returns:
- _type_ the decorated class

Raises:
- KWOnlyError : only raised if kw_only is True and python version is <3.9, since dataclasses.dataclass does not support this
- NotSerializableFieldException : if a field is not a SerializableField
- FieldSerializationError : if there is an error serializing a field
- AttributeError : if a property is not found on the class
- FieldLoadingError : if there is an error loading a field
muutils.json_serialize.serializable_field

extends dataclasses.Field for use with SerializableDataclass

In particular, instead of using dataclasses.field, use serializable_field to define fields in a SerializableDataclass. You provide information on how the field should be serialized and loaded (as well as anything that goes into dataclasses.field) when you define the field, and the SerializableDataclass will automatically use those functions.
class SerializableField(dataclasses.Field):

extension of dataclasses.Field with additional serialization properties
SerializableField(
default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
default_factory: Union[Callable[[], Any], dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
doc: str | None = None,
metadata: Optional[mappingproxy] = None,
kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[<member 'type' of 'SerializableField' objects>], bool]] = None
)

serialize: bool
serialization_fn: Optional[Callable[[Any], Any]]
loading_fn: Optional[Callable[[Any], Any]]
deserialize_fn: Optional[Callable[[Any], Any]]
assert_type: bool
custom_typecheck_fn: Optional[Callable[[<member 'type' of 'SerializableField' objects>], bool]]
def from_Field(
cls,
field: dataclasses.Field
) -> muutils.json_serialize.serializable_field.SerializableField

copy all values from a dataclasses.Field to new SerializableField
name
type
default
default_factory
init
repr
hash
compare
metadata
kw_only
doc
def serializable_field(
*_args,
default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
doc: str | None = None,
metadata: Optional[mappingproxy] = None,
kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
**kwargs: Any
) -> Any

Create a new SerializableField
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
doc: str | None = None, # new in python 3.14. can alternately pass `description` to match pydantic, but this is discouraged
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
### ----------------------------------------------------------------------
### new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
Parameters:
- serialize: whether to serialize this field when serializing the class
- serialization_fn: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the SerializerHandlers defined in muutils.json_serialize.json_serialize
- loading_fn: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- deserialize_fn: new alternative to loading_fn. takes only the field’s value, not the whole class. if both loading_fn and deserialize_fn are provided, an error will be raised.
- assert_type: whether to assert the type of the field when loading. if False, will not check the type of the field.
- custom_typecheck_fn: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

Gotcha: loading_fn takes the dict of the class, not the field. if you wanted a loading_fn that does nothing, you’d write:

class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )

using deserialize_fn instead:

class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )

In the above code, my_field is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide serialization_fn and loading_fn to serialize and load the container. ZANJ will automatically do this for you.

- custom_value_check_fn: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
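A self-contained round-trip sketch tying serializable_field together with the decorator (the field contents here are purely illustrative):

import json
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

@serializable_dataclass
class Point(SerializableDataclass):
    x: float
    y: float
    # stored uppercased in the json, lowercased again on load
    label: str = serializable_field(
        default="origin",
        serialization_fn=lambda s: s.upper(),
        deserialize_fn=lambda s: s.lower(),
    )

p = Point(x=1.0, y=2.0, label="corner")
s = json.dumps(p.serialize())
assert Point.load(json.loads(s)) == p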
muutils.json_serialize.util

utilities for json_serialize
BaseType = typing.Union[bool, int, float, str, NoneType]
JSONitem = typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]
JSONdict = typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]]
Hashableitem = typing.Union[bool, int, float, str, tuple]
class UniversalContainer:

contains everything – x in UniversalContainer() is always True

def isinstance_namedtuple(x: Any) -> bool

checks if x is a namedtuple

credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple

def try_catch(func: Callable)

wraps the function to catch exceptions, returns serialized error message on exception

returned func will return normal result on success, or error message on exception

class SerializationException(builtins.Exception):

Common base class for all non-exit exceptions.

def string_as_lines(s: str | None) -> list[str]

for easier reading of long strings in json, split up by newlines

sort of like how jupyter notebooks do it

def safe_getsource(func) -> list[str]

def array_safe_eq(a: Any, b: Any) -> bool

check if two objects are equal, accounting for numpy arrays or torch tensors
def dc_eq(
dc1,
dc2,
except_when_class_mismatch: bool = False,
false_when_class_mismatch: bool = True,
except_when_field_mismatch: bool = False
) -> bool

checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:
- dc1: the first dataclass
- dc2: the second dataclass
- except_when_class_mismatch: bool if True, will throw TypeError if the classes are different. if not, will return false by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
- false_when_class_mismatch: bool only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
- except_when_field_mismatch: bool only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. if True, will throw TypeError if the fields are different. (default: False)

Returns:
- bool: True if the dataclasses are equal, False otherwise

Raises:
- TypeError: if the dataclasses are of different classes
- AttributeError: if the dataclasses have different fields

  [START]
     ▼
▼
┌───────────┐ ┌─────────┐
│dc1 is dc2?├─►│ classes │
└──┬────────┘No│ match? │
──── │ ├─────────┤
(True)◄──┘Yes │No │Yes
──── ▼ ▼
┌────────────────┐ ┌────────────┐
│ except when │ │ fields keys│
│ class mismatch?│ │ match? │
├───────────┬────┘ ├───────┬────┘
│Yes │No │No │Yes
▼ ▼ ▼ ▼
─────────── ┌──────────┐ ┌────────┐
{ raise } │ except │ │ field │
{ TypeError } │ when │ │ values │
─────────── │ field │ │ match? │
│ mismatch?│ ├────┬───┘
├───────┬──┘ │ │Yes
│Yes │No │No ▼
▼ ▼ │ ────
─────────────── ───── │ (True)
{ raise } (False)◄┘ ────
{ AttributeError} ─────
───────────────
class MonoTuple:

tuple type hint, but for a tuple of any length with all the same type
muutils.jsonlines

utilities for reading and writing jsonlines files, including gzip support
def jsonl_load(
path: str,
/,
*,
use_gzip: bool | None = None
) -> list[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]]

def jsonl_load_log(path: str, /, *, use_gzip: bool | None = None) -> list[dict]

def jsonl_write(
path: str,
items: Sequence[Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]],
use_gzip: bool | None = None,
gzip_compresslevel: int = 2
) -> None
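A minimal read/write sketch (gzip is requested explicitly here; when use_gzip is None, the choice is left to the library):

from muutils.jsonlines import jsonl_load, jsonl_write

items = [{"step": 1, "loss": 0.5}, {"step": 2, "loss": 0.3}]

# plain jsonl: one json object per line
jsonl_write("log.jsonl", items)
assert jsonl_load("log.jsonl") == items

# gzipped jsonl
jsonl_write("log.jsonl.gz", items, use_gzip=True)
assert jsonl_load("log.jsonl.gz", use_gzip=True) == items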
muutils.kappa

anonymous getitem class

util for constructing a class which has a getitem method which just calls a function

a lambda is an anonymous function: kappa is the letter before lambda in the greek alphabet, hence the name of this class

class Kappa(typing.Mapping[~_kappa_K, ~_kappa_V]):

A Mapping is a generic container for associating key/value pairs.

This class provides concrete generic implementations of all methods except for getitem, iter, and len.

Kappa(func_getitem: Callable[[~_kappa_K], ~_kappa_V])

func_getitem

doc
docs for
muutilsv0.8.12
muutils.logger

(deprecated) experimenting with logging utilities

class Logger(muutils.logger.simplelogger.SimpleLogger):

logger with more features, including log levels and streams
- `log_path : str | None`
default log file path
(defaults to `None`)
- `log_file : AnyIO | None`
default log io, should have a `.write()` method (pass only this or `log_path`, not both)
(defaults to `None`)
- `timestamp : bool`
whether to add timestamps to every log message (under the `_timestamp` key)
(defaults to `True`)
- `default_level : int`
default log level for streams/messages that don't specify a level
(defaults to `0`)
- `console_print_threshold : int`
log level at which to print to the console, anything greater will not be printed unless overridden by `console_print`
(defaults to `50`)
- `level_header : HeaderFunction`
function for formatting log messages when printing to console
(defaults to `HEADER_FUNCTIONS["md"]`)
- `keep_last_msg_time : bool`
  whether to keep the last message time
  (defaults to `True`)
- `ValueError` : _description_
Logger(
log_path: str | None = None,
log_file: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
default_level: int = 0,
console_print_threshold: int = 50,
level_header: muutils.logger.headerfuncs.HeaderFunction = <function md_header_function>,
streams: Union[dict[str | None, muutils.logger.loggingstream.LoggingStream], Sequence[muutils.logger.loggingstream.LoggingStream]] = (),
keep_last_msg_time: bool = True,
timestamp: bool = True,
**kwargs
)

def log(
self,
msg: Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]] = None,
lvl: int | None = None,
stream: str | None = None,
console_print: bool = False,
extra_indent: str = '',
**kwargs
)

logging function

Parameters:
- msg : JSONitem message (usually string or dict) to be logged
- lvl : int | None level of message (lower levels are more important) (defaults to None)
- console_print : bool override console_print_threshold setting (defaults to False)
- stream : str | None which stream to log to; if None, logs to the default None stream (defaults to None)

def log_elapsed_last(
self,
lvl: int | None = None,
stream: str | None = None,
console_print: bool = True,
**kwargs
) -> float

logs the time elapsed since the last message was printed to the console (in any stream)

def flush_all(self)

flush all streams
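A hedged usage sketch based on the signatures above (file names and levels are illustrative):

from muutils.logger import Logger, LoggingStream

logger = Logger(
    log_path="train.jsonl",
    console_print_threshold=10,
    streams=(LoggingStream(name="loss", file="loss.jsonl"),),
)

logger.log("starting run", lvl=0, console_print=True)
logger.log({"epoch": 1, "loss": 0.5}, lvl=20, stream="loss")
logger.log_elapsed_last()
logger.flush_all()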
class LoggingStream:

properties of a logging stream
- name: str name of the stream
- aliases: set[str] aliases for the stream (calls to these names will be redirected to this stream. duplicate aliases will result in errors) TODO: perhaps duplicate aliases should result in duplicate writes?
- file: str|bool|AnyIO|None file to write to
  - if None, will write to standard log
  - if True, will write to name + ".log"
  - if False, will “write” to NullIO (throw it away)
  - if a string, will write to that file
  - if a fileIO type object, will write to that object
- default_level: int|None default level for this stream
- default_contents: dict[str, Callable[[], Any]] default contents for this stream
- last_msg: tuple[float, Any]|None last message written to this stream (timestamp, message)

LoggingStream(
name: str | None,
aliases: set[str | None] = <factory>,
file: Union[str, bool, TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
default_level: int | None = None,
default_contents: dict[str, typing.Callable[[], typing.Any]] = <factory>,
handler: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
)

name: str | None
aliases: set[str | None]
file: Union[str, bool, TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
default_level: int | None = None
default_contents: dict[str, typing.Callable[[], typing.Any]]
handler: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
def make_handler(self) -> Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType]

class SimpleLogger:

logs training data to a jsonl file
SimpleLogger(
log_path: str | None = None,
log_file: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
timestamp: bool = True
)

def log(
self,
msg: Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]],
console_print: bool = False,
**kwargs
)

log a message to the log file, and optionally to the console

class TimerContext:

context manager for timing code
start_time: float
end_time: float
elapsed_time: float
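A small sketch of the timer (this assumes, as the attribute names suggest, that the context manager returns itself and fills in elapsed_time on exit):

from muutils.logger import TimerContext

with TimerContext() as timer:
    total = sum(i * i for i in range(1_000_000))

print(f"block took {timer.elapsed_time:.4f} seconds")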
muutils.logger.exception_context

class ExceptionContext:

context manager which catches all exceptions happening while the context is open, .write() the exception trace to the given stream, and then raises the exception

for example:

errorfile = open('error.log', 'w')

with ExceptionContext(errorfile):
    # do something that might throw an exception
    # if it does, the exception trace will be written to errorfile
    # and then the exception will be raised

ExceptionContext(stream)

stream
muutils.logger.headerfuncs

class HeaderFunction(typing.Protocol):

Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto[T](Protocol):
def meth(self) -> T:
...
HeaderFunction(*args, **kwargs)

def md_header_function(
msg: Any,
lvl: int,
stream: str | None = None,
indent_lvl: str = ' ',
extra_indent: str = '',
**kwargs
) -> str

standard header function. will output

# {msg}

for levels in [0, 9]

## {msg}

for levels in [10, 19], and so on

[{stream}] # {msg}

for a non-None stream, with level headers as before

!WARNING! [{stream}] {msg}

for level in [-9, -1]

!!WARNING!! [{stream}] {msg}

for level in [-19, -10] and so on

HEADER_FUNCTIONS: dict[str, muutils.logger.headerfuncs.HeaderFunction] = {'md': <function md_header_function>}
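For instance, under the rules above (the outputs in the comments are what those rules imply):

from muutils.logger.headerfuncs import md_header_function

print(md_header_function("training started", lvl=0))
# '# training started'
print(md_header_function("checkpoint saved", lvl=15, stream="ckpt"))
# '[ckpt] ## checkpoint saved'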
muutils.logger.log_util

def get_any_from_stream(stream: list[dict], key: str) -> None

get the first value of a key from a stream. errors if not found

def gather_log(file: str) -> dict[str, list[dict]]

gathers and sorts all streams from a log

def gather_stream(file: str, stream: str) -> list[dict]

gets all entries from a specific stream in a log file
def gather_val(
file: str,
stream: str,
keys: tuple[str],
allow_skip: bool = True
) -> list[list]

gather specific keys from a specific stream in a log file
example: if “log.jsonl” has contents:
{"a": 1, "b": 2, "c": 3, "_stream": "s1"}
{"a": 4, "b": 5, "c": 6, "_stream": "s1"}
{"a": 7, "b": 8, "c": 9, "_stream": "s2"}
then gather_val("log.jsonl", "s1", ("a", "b")) will
return
[
    [1, 2],
    [4, 5]
]
muutils.logger.logger

logger with streams & levels, and a timer context manager

- SimpleLogger is an extremely simple logger that can write to both console and a file
- Logger handles levels in a slightly different way than default python logging, and also has “streams” which allow for different sorts of output in the same logger. this was mostly made with training models in mind, for storing both metadata and loss
- TimerContext is a context manager that can be used to time the duration of a block of code

def decode_level(level: int) -> str

class Logger(muutils.logger.simplelogger.SimpleLogger):

logger with more features, including log levels and streams
(Logger, its log() and log_elapsed_last() methods, and flush_all() are documented above under muutils.logger)
muutils.logger.loggingstream

class LoggingStream:

properties of a logging stream (documented above under muutils.logger)

def make_handler(self) -> Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType]
muutils.logger.simplelogger

class NullIO:

null IO class

def write(self, msg: str) -> int

write to nothing! this throws away the message

def flush(self) -> None

flush nothing! this is a no-op

def close(self) -> None

close nothing! this is a no-op

AnyIO = typing.Union[typing.TextIO, muutils.logger.simplelogger.NullIO]

class SimpleLogger:

logs training data to a jsonl file (documented above under muutils.logger)
muutils.logger.timing

class TimerContext:

context manager for timing code

start_time: float
end_time: float
elapsed_time: float

def filter_time_str(time: str) -> str

assuming format h:mm:ss, clips off the hours if it's 0

class ProgressEstimator:

estimates progress and can give a progress bar

ProgressEstimator(
n_total: int,
pbar_fill: str = '█',
pbar_empty: str = ' ',
pbar_bounds: tuple[str, str] = ('|', '|')
)

n_total: int
starttime: float
pbar_fill: str
pbar_empty: str
pbar_bounds: tuple[str, str]
total_str_len: int

def get_timing_raw(self, i: int) -> dict[str, float]

returns dict(elapsed, per_iter, remaining, percent)

def get_pbar(self, i: int, width: int = 30) -> str

returns a progress bar

def get_progress_default(self, i: int) -> str

returns a progress string
muutils.math
muutils.math.bins

class Bins:

Bins(
n_bins: int = 32,
start: float = 0,
stop: float = 1.0,
scale: Literal['lin', 'log'] = 'log',
_log_min: float = 0.001,
_zero_in_small_start_log: bool = True
)

n_bins: int = 32
start: float = 0
stop: float = 1.0
scale: Literal['lin', 'log'] = 'log'
edges: jaxtyping.Float[ndarray, 'n_bins+1']
centers: jaxtyping.Float[ndarray, 'n_bins']

def changed_n_bins_copy(self, n_bins: int) -> muutils.math.bins.Bins
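A brief construction sketch (edge and center shapes follow the annotations above; how a zero start interacts with log scale is governed by _log_min and _zero_in_small_start_log):

from muutils.math.bins import Bins

lin = Bins(n_bins=4, start=0.0, stop=1.0, scale="lin")
print(lin.edges)    # 5 edges bounding 4 bins
print(lin.centers)  # 4 bin centers

log_bins = Bins(n_bins=8, scale="log")
coarser = log_bins.changed_n_bins_copy(n_bins=4)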
muutils.math.matrix_powers

def matrix_powers(
A: jaxtyping.Float[ndarray, 'n n'],
powers: Sequence[int]
) -> jaxtyping.Float[ndarray, 'n_powers n n']

Compute multiple powers of a matrix efficiently.

Uses binary exponentiation to compute powers in O(log max(powers)) matrix multiplications, avoiding redundant calculations when computing multiple powers.

Parameters:
- A : Float[np.ndarray, "n n"] Square matrix to exponentiate
- powers : Sequence[int] List of powers to compute (non-negative integers)

Returns:
- Float[np.ndarray, "n_powers n n"] Array containing the requested matrix powers stacked along the first dimension

def matrix_powers_torch(A, powers: Sequence[int])

Compute multiple powers of a matrix efficiently.
Uses binary exponentiation to compute powers in O(log max(powers)) matrix multiplications, avoiding redundant calculations when computing multiple powers.

Parameters:
- A : Float[torch.Tensor, "n n"] Square matrix to exponentiate
- powers : Sequence[int] List of powers to compute (non-negative integers)

Returns:
- Float[torch.Tensor, "n_powers n n"] Tensor containing the requested matrix powers stacked along the first dimension

Raises:
- ValueError : If no powers are requested or if A is not a square matrix
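A quick numpy check of the stacked-output convention described above (assuming the i-th slice corresponds to the i-th requested power):

import numpy as np
from muutils.math.matrix_powers import matrix_powers

A = np.array([[1.0, 1.0], [0.0, 1.0]])
powers = [1, 2, 4, 8]

stacked = matrix_powers(A, powers)
assert stacked.shape == (len(powers), 2, 2)
assert np.allclose(stacked[2], np.linalg.matrix_power(A, 4))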
muutils.misc

miscellaneous utilities

- stable_hash for hashing that is stable across runs
- muutils.misc.sequence for sequence manipulation, applying mappings, and string-like operations on lists
- muutils.misc.string for sanitizing things for filenames, adjusting docstrings, and converting dicts to filenames
- muutils.misc.numerical for turning numbers into nice strings and back
- muutils.misc.freezing for freezing things
- muutils.misc.classes for some weird class utilities

def stable_hash(s: str | bytes) -> int

Returns a stable hash of the given string. not cryptographically secure, but stable between runs
WhenMissing = typing.Literal['except', 'skip', 'include']

def empty_sequence_if_attr_false(itr: Iterable[Any], attr_owner: Any, attr_name: str) -> Iterable[Any]

Returns itr if attr_owner has the attribute attr_name and it boolean casts to True. Returns an empty sequence otherwise.

Particularly useful for optionally inserting delimiters into a sequence depending on a TokenizerElement attribute.

Parameters:
- itr: Iterable[Any] The iterable to return if the attribute is True.
- attr_owner: Any The object to check for the attribute.
- attr_name: str The name of the attribute to check.

Returns:
- itr: Iterable if attr_owner has the attribute attr_name and it boolean casts to True, otherwise an empty sequence ().

def flatten(it: Iterable[Any], levels_to_flatten: int | None = None) -> Generator

Flattens an arbitrarily nested iterable. Flattens all iterable data types except for str and bytes.
Returns:
- Generator over the flattened sequence.

Parameters:
- it: Any arbitrarily nested iterable.
- levels_to_flatten: Number of levels to flatten by, starting at the outermost layer. If None, performs full flattening.

def list_split(lst: list, val: Any) -> list[list]

split a list into sublists by val. similar to "a_b_c".split("_")

>>> list_split([1,2,3,0,4,5,0,6], 0)
[[1, 2, 3], [4, 5], [6]]
>>> list_split([0,1,2,3], 0)
[[], [1, 2, 3]]
>>> list_split([1,2,3], 0)
[[1, 2, 3]]
>>> list_split([], 0)
[[]]

def list_join(lst: list, factory: Callable) -> list

add a new instance of factory() between each element of lst

>>> list_join([1,2,3], lambda : 0)
[1,0,2,0,3]
>>> list_join([1,2,3], lambda: [time.sleep(0.1), time.time()][1])
[1, 1600000000.0, 2, 1600000000.1, 3]

def apply_mapping(
mapping: Mapping[~_AM_K, ~_AM_V],
iter: Iterable[~_AM_K],
when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]

Given an iterable and a mapping, apply the mapping to the iterable with certain options

Gotcha: if when_missing is invalid, this is totally fine until a missing key is actually encountered.

Note: you can use this with muutils.kappa.Kappa if you want to pass a function instead of a dict

Parameters:
- mapping : Mapping[_AM_K, _AM_V] must have __contains__ and __getitem__, both of which take _AM_K and the latter returns _AM_V
- iter : Iterable[_AM_K] the iterable to apply the mapping to
- when_missing : WhenMissing what to do when a key is missing from the mapping – this is what distinguishes this function from map. you can choose from "skip", "include" (without converting), and "except" (defaults to "skip")

Returns:
return type is one of:
- list[_AM_V] if when_missing is "skip" or "except"
- list[Union[_AM_K, _AM_V]] if when_missing is "include"

Raises:
- KeyError : if the item is missing from the mapping and when_missing is "except"
- ValueError : if when_missing is invalid

def apply_mapping_chain(
mapping: Mapping[~_AM_K, Iterable[~_AM_V]],
iter: Iterable[~_AM_K],
when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]

Given an iterable and a mapping, chain the mappings together

Gotcha: if when_missing is invalid, this is totally fine until a missing key is actually encountered.

Note: you can use this with muutils.kappa.Kappa if you want to pass a function instead of a dict

Parameters:
- mapping : Mapping[_AM_K, Iterable[_AM_V]] must have __contains__ and __getitem__, both of which take _AM_K and the latter returns Iterable[_AM_V]
- iter : Iterable[_AM_K] the iterable to apply the mapping to
- when_missing : WhenMissing what to do when a key is missing from the mapping – this is what distinguishes this function from map. you can choose from "skip", "include" (without converting), and "except" (defaults to "skip")

Returns:
return type is one of:
- list[_AM_V] if when_missing is "skip" or "except"
- list[Union[_AM_K, _AM_V]] if when_missing is "include"

Raises:
- KeyError : if the item is missing from the mapping and when_missing is "except"
- ValueError : if when_missing is invalid
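A short sketch of the when_missing options for apply_mapping (apply_mapping_chain behaves analogously, but chains the iterables of values together):

from muutils.misc import apply_mapping

mapping = {1: "one", 2: "two"}

assert apply_mapping(mapping, [1, 2, 3], when_missing="skip") == ["one", "two"]
assert apply_mapping(mapping, [1, 2, 3], when_missing="include") == ["one", "two", 3]
# when_missing="except" would raise KeyError on the unmapped 3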
def sanitize_name(
name: str | None,
additional_allowed_chars: str = '',
replace_invalid: str = '',
when_none: str | None = '_None_',
leading_digit_prefix: str = ''
) -> str

sanitize a string, leaving only alphanumerics and additional_allowed_chars

Parameters:
- name : str | None input string
- additional_allowed_chars : str additional characters to allow, none by default (defaults to "")
- replace_invalid : str character to replace invalid characters with (defaults to "")
- when_none : str | None string to return if name is None. if None, raises an exception (defaults to "_None_")
- leading_digit_prefix : str character to prefix the string with if it starts with a digit (defaults to "")

Returns:
- str sanitized string

def sanitize_fname(fname: str | None, **kwargs) -> str

sanitize a filename to posix standards: additionally allows _ (underscore), - (dash), and . (period)

def sanitize_identifier(fname: str | None, **kwargs) -> str

sanitize an identifier (variable or function name): additionally allows _ (underscore), and prefixes with _ if it starts with a digit

def dict_to_filename(
data: dict,
format_str: str = '{key}_{val}',
separator: str = '.',
max_length: int = 255
)

def dynamic_docstring(**doc_params)

def shorten_numerical_to_str(
num: int | float,
small_as_decimal: bool = True,
precision: int = 1
) -> str

shorten a large numerical value to a string, e.g. 1234 -> 1K

precision guaranteed to 1 in 10, but can be higher. reverse of str_to_numeric
def str_to_numeric(
quantity: str,
mapping: None | bool | dict[str, int | float] = True
) -> int | float

Convert a string representing a quantity to a numeric value.

The string can represent an integer, python float, fraction, or shortened via shorten_numerical_to_str.
>>> str_to_numeric("5")
5
>>> str_to_numeric("0.1")
0.1
>>> str_to_numeric("1/5")
0.2
>>> str_to_numeric("-1K")
-1000.0
>>> str_to_numeric("1.5M")
1500000.0
>>> str_to_numeric("1.2e2")
120.0
_SHORTEN_MAP = {1000.0: 'K', 1000000.0: 'M', 1000000000.0: 'B', 1000000000000.0: 't', 1000000000000000.0: 'q', 1e+18: 'Q'}

class FrozenDict(builtins.dict):

class FrozenList(builtins.list):

Built-in mutable sequence.

If no argument is given, the constructor creates a new empty list. The argument must be an iterable if specified.

def append(self, value)

Append object to the end of the list.

def extend(self, iterable)

Extend list by appending elements from the iterable.

def insert(self, index, value)

Insert object before index.

def remove(self, value)

Remove first occurrence of value. Raises ValueError if the value is not present.

def pop(self, index=-1)

Remove and return item at index (default last). Raises IndexError if list is empty or index is out of range.

def clear(self)

Remove all items from list.
def freeze(instance: Any) -> Any

recursively freeze an object in-place so that its attributes and elements cannot be changed

messy in the sense that sometimes the object is modified in place, but you can’t rely on that. always use the return value.

the gelidum package is a more complete implementation of this idea
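A minimal freezing sketch (the exact exception type raised on mutation is an implementation detail, hence the broad except):

from muutils.misc import freeze

config = freeze({"lr": 1e-3, "layers": [2, 3]})

try:
    config["lr"] = 1.0
except Exception as e:
    print(f"mutation blocked: {type(e).__name__}")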
def is_abstract(cls: type) -> bool

Returns if a class is abstract.

def get_all_subclasses(class_: type, include_self=False) -> set[type]

Returns a set containing all child classes in the subclass graph of class_. I.e., includes subclasses of subclasses, etc.

Parameters:
- class_: superclass
- include_self: whether to include class_ itself in the returned set

Since most class hierarchies are small, the inefficiencies of the existing recursive implementation aren't problematic. It might be valuable to refactor with memoization if the need arises to use this function on a very large class hierarchy.
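An illustrative use of get_all_subclasses on a tiny hierarchy:

```python
from muutils.misc.classes import get_all_subclasses

class A: ...
class B(A): ...
class C(B): ...

assert get_all_subclasses(A) == {B, C}
assert get_all_subclasses(A, include_self=True) == {A, B, C}
```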
def isinstance_by_type_name(o: object, type_name: str)

Behaves like stdlib isinstance except it accepts a string representation of the type rather than the type itself. This is a hacky function intended to circumvent the need to import a type into a module. It is susceptible to type name collisions.

Parameters:
- o: object (not the type itself) whose type to interrogate
- type_name: the string returned by type_.__name__. Generic types are not supported, only types that would appear in type_.__mro__.
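A short sketch: matching on the type's __name__ without importing the type:

```python
from muutils.misc.classes import isinstance_by_type_name

assert isinstance_by_type_name(3.0, "float")  # "float" appears in type(3.0).__mro__
assert not isinstance_by_type_name(3.0, "int")
```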
class IsDataclass(typing.Protocol):Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto[T](Protocol):
def meth(self) -> T:
...
IsDataclass(*args, **kwargs)

def get_hashable_eq_attrs(dc: muutils.misc.classes.IsDataclass) -> tuple[typing.Any]

Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself. The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical. Essentially used to generate a hashable dataclass representation for equality comparison even if it's not frozen.
def dataclass_set_equals(
coll1: Iterable[muutils.misc.classes.IsDataclass],
coll2: Iterable[muutils.misc.classes.IsDataclass]
) -> bool

Compares 2 collections of dataclass instances as if they were sets. Duplicates are ignored in the same manner as a set. Unfrozen dataclasses can't be placed in sets since they're not hashable. Collections of them may be compared using this function.
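A minimal sketch comparing unfrozen (unhashable) dataclass instances as sets:

```python
from dataclasses import dataclass
from muutils.misc.classes import dataclass_set_equals

@dataclass
class Point:
    x: int
    y: int

# duplicates are ignored, as with a set
assert dataclass_set_equals([Point(1, 2), Point(1, 2)], [Point(1, 2)])
```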
muutils.misc.func

FuncParams = ~FuncParams
FuncParamsPreWrap = ~FuncParamsPreWrap
def process_kwarg(
kwarg_name: str,
processor: Callable[[~T_process_in], ~T_process_out]
) -> Callable[[Callable[~FuncParamsPreWrap, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]

Decorator that applies a processor to a keyword argument.
The underlying function is expected to have a keyword argument (with
name kwarg_name) of type T_out, but the caller
provides a value of type T_in that is converted via
processor.
Parameters:
- kwarg_name (str): the name of the keyword argument to process.
- processor (Callable[[T_in], T_out]): a callable that converts the input value (T_in) into the type expected by the function (T_out).

Returns: a decorator that turns a Callable[OutputParams, ReturnType] (expecting kwarg_name of type T_out) into one of type Callable[InputParams, ReturnType] (accepting kwarg_name of type T_in).
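A hedged sketch of process_kwarg (read_name is a hypothetical function): the caller passes a str, the wrapped function receives a Path:

```python
from pathlib import Path
from muutils.misc.func import process_kwarg

@process_kwarg("path", Path)
def read_name(path: Path) -> str:
    return path.name

assert read_name(path="a/b.toml") == "b.toml"  # str converted to Path before the call
```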
def validate_kwarg(
    kwarg_name: str,
validator: Callable[[~T_kwarg], bool],
description: str | None = None,
action: muutils.errormode.ErrorMode = ErrorMode.Except
) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]

Decorator that validates a specific keyword argument.

Parameters:
- kwarg_name (str): the name of the keyword argument to validate.
- validator (Callable[[Any], bool]): a callable that returns True if the keyword argument is valid.
- description (str | None): a message template if validation fails.
- action (str): either "raise" (default) or "warn".

Returns: Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]] : a decorator that validates the keyword argument. if action == "warn", emits a warning; otherwise, raises a ValueError.

    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
    def my_func(x: int) -> int:
        return x

    assert my_func(x=1) == 1

Raises: ValueError if validation fails and action == "raise".

def replace_kwarg(
kwarg_name: str,
check: Callable[[~T_kwarg], bool],
replacement_value: ~T_kwarg,
replace_if_missing: bool = False
) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]

Decorator that replaces a specific keyword argument value by identity comparison.

Parameters:
- kwarg_name (str): the name of the keyword argument to replace.
- check (Callable[[T_kwarg], bool]): a callable that returns True if the keyword argument should be replaced.
- replacement_value (T_kwarg): the value to replace with.
- replace_if_missing (bool): if True, replaces the keyword argument even if it's missing.

Returns: Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]] : a decorator that replaces kwargs[kwarg_name] with replacement_value when check passes.

    @replace_kwarg("x", is_none, "default_string")
    def my_func(*, x: str | None = None) -> str:
        return x

    assert my_func(x=None) == "default_string"

def is_none(value: Any) -> bool

def always_true(value: Any) -> bool

def always_false(value: Any) -> bool

def format_docstring(
**fmt_kwargs: Any
) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]

Decorator that formats a function's docstring with the provided keyword arguments.

LambdaArgs = LambdaArgs

def typed_lambda(
fn: Callable[[Unpack[LambdaArgs]], ~ReturnType],
in_types: ~LambdaArgsTypes,
out_type: type[~ReturnType]
) -> Callable[[Unpack[LambdaArgs]], ~ReturnType]

Wraps a lambda function with type hints.

Parameters:
- fn (Callable[[Unpack[LambdaArgs]], ReturnType]): the lambda function to wrap.
- in_types (tuple[type, ...]): tuple of input types.
- out_type (type[ReturnType]): the output type.

Returns: Callable[..., ReturnType] : a new function with annotations matching the given signature.

    add = typed_lambda(lambda x, y: x + y, (int, int), int)
    assert add(1, 2) == 3

Raises: ValueError if the number of input types doesn't match the lambda's parameters.

muutils.misc.hashing

def stable_hash(s: str | bytes) -> int

Returns a stable hash of the given string. not cryptographically secure, but stable between runs of python

def stable_json_dumps(d) -> str

def base64_hash(s: str | bytes) -> str

Returns a base64 representation of the hash of the given string. not cryptographically secure
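A short sketch of why stable_hash matters: unlike the builtin hash() for strings, it does not change between interpreter sessions:

```python
from muutils.misc.hashing import base64_hash, stable_hash

assert stable_hash("hello") == stable_hash("hello")  # also stable across runs
print(base64_hash("hello"))  # shorter, printable representation of the hash
```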
muutils.misc.sequence

WhenMissing = typing.Literal['except', 'skip', 'include']

def empty_sequence_if_attr_false(itr: Iterable[Any], attr_owner: Any, attr_name: str) -> Iterable[Any]

Returns itr if attr_owner has the attribute
attr_name and it boolean casts to True.
Returns an empty sequence otherwise.
Particularly useful for optionally inserting delimiters into a
sequence depending on a TokenizerElement attribute.
Parameters:
- itr (Iterable[Any]): the iterable to return if the attribute is True.
- attr_owner (Any): the object to check for the attribute.
- attr_name (str): the name of the attribute to check.

Returns:
- itr (Iterable) if attr_owner has the attribute attr_name and it boolean casts to True
- () an empty sequence if the attribute is False or not present

def flatten(it: Iterable[Any], levels_to_flatten: int | None = None) -> Generator

Flattens an arbitrarily nested iterable. Flattens all iterable data types except for str and bytes.

Returns: Generator over the flattened sequence.

Parameters:
- it: any arbitrarily nested iterable.
- levels_to_flatten: number of levels to flatten by, starting at the outermost layer. If None, performs full flattening.
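An illustrative sketch of full vs. bounded flattening (note that str is not descended into):

```python
from muutils.misc.sequence import flatten

nested = [1, [2, [3, "ab"]]]
assert list(flatten(nested)) == [1, 2, 3, "ab"]
assert list(flatten(nested, levels_to_flatten=1)) == [1, 2, [3, "ab"]]
```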
def list_split(lst: list, val: Any) -> list[list]

split a list into sublists by val, similar to "a_b_c".split("_")
>>> list_split([1,2,3,0,4,5,0,6], 0)
[[1, 2, 3], [4, 5], [6]]
>>> list_split([0,1,2,3], 0)
[[], [1, 2, 3]]
>>> list_split([1,2,3], 0)
[[1, 2, 3]]
>>> list_split([], 0)
[[]]

def list_join(lst: list, factory: Callable) -> list

add a new instance of factory() between each element of lst
>>> list_join([1,2,3], lambda : 0)
[1,0,2,0,3]
>>> list_join([1,2,3], lambda: [time.sleep(0.1), time.time()][1])
[1, 1600000000.0, 2, 1600000000.1, 3]

def apply_mapping(
mapping: Mapping[~_AM_K, ~_AM_V],
iter: Iterable[~_AM_K],
when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]

Given an iterable and a mapping, apply the mapping to the iterable with certain options

Gotcha: an invalid when_missing is not detected until a missing key is actually encountered.

Note: you can use this with muutils.kappa.Kappa if you want to pass a function instead of a dict

Parameters:
- mapping (Mapping[_AM_K, _AM_V]): must have __contains__ and __getitem__, both of which take _AM_K and the latter returns _AM_V
- iter (Iterable[_AM_K]): the iterable to apply the mapping to
- when_missing (WhenMissing): what to do when a key is missing from the mapping (this is what distinguishes this function from map). you can choose from "skip", "include" (without converting), and "except" (defaults to "skip")

Returns one of:
- list[_AM_V] if when_missing is "skip" or "except"
- list[Union[_AM_K, _AM_V]] if when_missing is "include"

Raises:
- KeyError : if the item is missing from the mapping and when_missing is "except"
- ValueError : if when_missing is invalid

def apply_mapping_chain(
mapping: Mapping[~_AM_K, Iterable[~_AM_V]],
iter: Iterable[~_AM_K],
when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]

Given an iterable and a mapping, chain the mappings together

Gotcha: an invalid when_missing is not detected until a missing key is actually encountered.

Note: you can use this with muutils.kappa.Kappa if you want to pass a function instead of a dict

Parameters:
- mapping (Mapping[_AM_K, Iterable[_AM_V]]): must have __contains__ and __getitem__, both of which take _AM_K and the latter returns Iterable[_AM_V]
- iter (Iterable[_AM_K]): the iterable to apply the mapping to
- when_missing (WhenMissing): what to do when a key is missing from the mapping (this is what distinguishes this function from map). you can choose from "skip", "include" (without converting), and "except" (defaults to "skip")

Returns one of:
- list[_AM_V] if when_missing is "skip" or "except"
- list[Union[_AM_K, _AM_V]] if when_missing is "include"

Raises:
- KeyError : if the item is missing from the mapping and when_missing is "except"
- ValueError : if when_missing is invalid
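A minimal sketch of apply_mapping_chain with illustrative values; each mapped value is an iterable, and the results are chained into one flat list:

```python
from muutils.misc.sequence import apply_mapping_chain

expansions = {"ab": ["a", "b"], "c": ["c"]}
assert apply_mapping_chain(expansions, ["ab", "c"]) == ["a", "b", "c"]
assert apply_mapping_chain(expansions, ["ab", "x"], when_missing="include") == ["a", "b", "x"]
```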
muutils.mlutils

miscellaneous utilities for ML pipelines
ARRAY_IMPORTS: bool = True
DEFAULT_SEED: int = 42
GLOBAL_SEED: int = 42
def get_device(device: Union[str, torch.device, NoneType] = None) -> torch.device

Get the torch.device instance on which torch.Tensors should be allocated.
def set_reproducibility(seed: int = 42)

Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades off performance for reproducibility; set use_deterministic_algorithms to False to recover performance at the cost of reproducibility.
def chunks(it, chunk_size)

Yield successive chunks from an iterator.
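An illustrative call (the exact container type of each chunk is an assumption):

```python
from muutils.mlutils import chunks

print(list(chunks(iter(range(7)), 3)))  # expected: [[0, 1, 2], [3, 4, 5], [6]]
```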
def get_checkpoint_paths_for_run(
run_path: pathlib._local.Path,
extension: Literal['pt', 'zanj'],
checkpoints_format: str = 'checkpoints/model.iter_*.{extension}'
) -> list[tuple[int, pathlib._local.Path]]

get checkpoints of the format from the run_path

note that checkpoints_format should contain a glob pattern with:
- an unresolved "{extension}" format term for the extension
- a wildcard for the iteration number
def register_method(
method_dict: dict[str, typing.Callable[..., typing.Any]],
custom_name: Optional[str] = None
) -> Callable[[~F], ~F]

Decorator to add a method to the method_dict
def pprint_summary(summary: dict)

muutils.nbutils

utilities for working with notebooks: configure_notebook, convert_ipynb_to_script, run_notebook_tests, mermaid, print_tex

def mm(graph)

for plotting mermaid.js diagrams
muutils.nbutils.configure_notebook

shared utilities for setting up a notebook
class PlotlyNotInstalledWarning(builtins.UserWarning):

Base class for warnings generated by user code.
PLOTLY_IMPORTED: bool = True
PlottingMode = typing.Literal['ignore', 'inline', 'widget', 'save']
PLOT_MODE: Literal['ignore', 'inline', 'widget', 'save'] = 'inline'
CONVERSION_PLOTMODE_OVERRIDE: Optional[Literal['ignore', 'inline', 'widget', 'save']] = None
FIG_COUNTER: int = 0
FIG_OUTPUT_FMT: str | None = None
FIG_NUMBERED_FNAME: str = 'figure-{num}'
FIG_CONFIG: dict | None = None
FIG_BASEPATH: str | None = None
CLOSE_AFTER_PLOTSHOW: bool = False
MATPLOTLIB_FORMATS = ['pdf', 'png', 'jpg', 'jpeg', 'svg', 'eps', 'ps', 'tif', 'tiff']
TIKZPLOTLIB_FORMATS = ['tex', 'tikz']
class UnknownFigureFormatWarning(builtins.UserWarning):

Base class for warnings generated by user code.
def universal_savefig(fname: str, fmt: str | None = None) -> None

def setup_plots(
plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline',
fig_output_fmt: str | None = 'pdf',
fig_numbered_fname: str = 'figure-{num}',
fig_config: dict | None = None,
fig_basepath: str | None = None,
close_after_plotshow: bool = False
) -> None

Set up plot saving/rendering options
def configure_notebook(
*args,
seed: int = 42,
device: Any = None,
dark_mode: bool = True,
plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline',
fig_output_fmt: str | None = 'pdf',
fig_numbered_fname: str = 'figure-{num}',
fig_config: dict | None = None,
fig_basepath: str | None = None,
close_after_plotshow: bool = False
) -> torch.device | None

Shared Jupyter notebook setup steps

Parameters:
- seed (int): random seed across libraries including torch, numpy, and random (defaults to 42)
- device (typing.Any): pytorch device to use (defaults to None)
- dark_mode (bool): figures in dark mode (defaults to True)
- plot_mode (PlottingMode): how to display plots, one of PlottingMode or ["ignore", "inline", "widget", "save"] (defaults to "inline")
- fig_output_fmt (str | None): format for saving figures (defaults to "pdf")
- fig_numbered_fname (str): format for saving figures with numbers (if they aren't named) (defaults to "figure-{num}")
- fig_config (dict | None): metadata to save with the figures (defaults to None)
- fig_basepath (str | None): base path for saving figures (defaults to None)
- close_after_plotshow (bool): close figures after showing them (defaults to False)

Returns: torch.device | None : the device set, if torch is installed
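A typical setup sketch using only the parameters documented above (the saved-figure path is illustrative):

```python
import matplotlib.pyplot as plt
from muutils.nbutils.configure_notebook import configure_notebook, plotshow

device = configure_notebook(seed=0, plot_mode="save", fig_basepath="figures")

plt.plot([0, 1], [0, 1])
plotshow("line")  # in "save" mode, expected to land at figures/line.pdf
```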
def plotshow(
fname: str | None = None,
plot_mode: Optional[Literal['ignore', 'inline', 'widget', 'save']] = None,
fmt: str | None = None
)

Show the active plot, depending on global configs
muutils.nbutils.convert_ipynb_to_script

fast conversion of Jupyter Notebooks to scripts, with some basic and hacky filtering and formatting.
DISABLE_PLOTS: dict[str, list[str]] = {'matplotlib': ['\n# ------------------------------------------------------------\n# Disable matplotlib plots, done during processing byconvert_ipynb_to_script.py\nimport matplotlib.pyplot as plt\nplt.show = lambda: None\n# ------------------------------------------------------------\n'], 'circuitsvis': ['\n# ------------------------------------------------------------\n# Disable circuitsvis plots, done during processing byconvert_ipynb_to_script.py\nfrom circuitsvis.utils.convert_props import PythonProperty, convert_props\nfrom circuitsvis.utils.render import RenderedHTML, render, render_cdn, render_local\n\ndef new_render(\n react_element_name: str,\n **kwargs: PythonProperty\n) -> RenderedHTML:\n "return a visualization as raw HTML"\n local_src = render_local(react_element_name, **kwargs)\n cdn_src = render_cdn(react_element_name, **kwargs)\n # return as string instead of RenderedHTML for CI\n return str(RenderedHTML(local_src, cdn_src))\n\nrender = new_render\n# ------------------------------------------------------------\n'], 'muutils': ['import muutils.nbutils.configure_notebook as nb_conf\nnb_conf.CONVERSION_PLOTMODE_OVERRIDE = "ignore"\n']}
DISABLE_PLOTS_WARNING: list[str] = ["# ------------------------------------------------------------\n# WARNING: this script is auto-generated byconvert_ipynb_to_script.py\n# showing plots has been disabled, so this is presumably in a temp dict for CI or something\n# so don't modify this code, it will be overwritten!\n# ------------------------------------------------------------\n"]
def disable_plots_in_script(script_lines: list[str]) -> list[str]

Disable plots in a script by adding cursed things after the import statements
def convert_ipynb(
notebook: dict,
strip_md_cells: bool = False,
header_comment: str = '#%%',
disable_plots: bool = False,
filter_out_lines: Union[str, Sequence[str]] = ('%', '!')
) -> str

Convert Jupyter Notebook to a script, doing some basic filtering and formatting.
- `notebook: dict`: Jupyter Notebook loaded as json.
- `strip_md_cells: bool = False`: Remove markdown cells from the output script.
- `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
- `disable_plots: bool = False`: Disable plots in the output script.
- `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
if a string is passed, it will be split by char and each char will be treated as a separate filter.
- `str`: Converted script.
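A minimal sketch: load a notebook as JSON, then convert it (the notebook path is hypothetical):

```python
import json
from muutils.nbutils.convert_ipynb_to_script import convert_ipynb

with open("notebook.ipynb") as f:
    nb = json.load(f)

script = convert_ipynb(nb, strip_md_cells=True, disable_plots=True)
print(script[:200])
```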
def process_file(
in_file: str,
out_file: str | None = None,
strip_md_cells: bool = False,
header_comment: str = '#%%',
disable_plots: bool = False,
filter_out_lines: Union[str, Sequence[str]] = ('%', '!')
)

def process_dir(
input_dir: Union[str, pathlib._local.Path],
output_dir: Union[str, pathlib._local.Path],
strip_md_cells: bool = False,
header_comment: str = '#%%',
disable_plots: bool = False,
filter_out_lines: Union[str, Sequence[str]] = ('%', '!')
)

Convert all Jupyter Notebooks in a directory to scripts.
- `input_dir: str`: Input directory.
- `output_dir: str`: Output directory.
- `strip_md_cells: bool = False`: Remove markdown cells from the output script.
- `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
- `disable_plots: bool = False`: Disable plots in the output script.
- `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
if a string is passed, it will be split by char and each char will be treated as a separate filter.
muutils.nbutils.mermaid

display mermaid.js diagrams in jupyter notebooks via the mermaid.ink/img service
def mm(graph)

for plotting mermaid.js diagrams
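A small illustrative diagram (rendered via mermaid.ink when run in a notebook):

```python
from muutils.nbutils.mermaid import mm

mm("""
graph TD
    A[load data] --> B[train]
    B --> C[evaluate]
""")
```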
muutils.nbutils.print_tex

quickly print a sympy expression in latex
def print_tex(
expr: sympy.core.expr.Expr,
name: str | None = None,
plain: bool = False,
rendered: bool = True
)

function for easily rendering a sympy expression in latex
muutils.nbutils.run_notebook_tests

turn a folder of notebooks into scripts, run them, and make sure they work.

made to be called as:

    python -m muutils.nbutils.run_notebook_tests --notebooks-dir <notebooks_dir> --converted-notebooks-temp-dir <converted_notebooks_temp_dir>

class NotebookTestError(builtins.Exception):

Common base class for all non-exit exceptions.
SUCCESS_STR: str = '✅'
FAILURE_STR: str = '❌'
def run_notebook_tests(
notebooks_dir: pathlib._local.Path,
converted_notebooks_temp_dir: pathlib._local.Path,
CI_output_suffix: str = '.CI-output.txt',
run_python_cmd: Optional[str] = None,
run_python_cmd_fmt: str = '{python_tool} run python',
python_tool: str = 'poetry',
exit_on_first_fail: bool = False
)

Run converted Jupyter notebooks as Python scripts and verify they execute successfully.
Takes a directory of notebooks and their corresponding converted Python scripts, executes each script, and captures the output. Failures are collected and reported, with optional early exit on first failure.
Parameters:
- notebooks_dir (Path): directory containing the original .ipynb notebook files
- converted_notebooks_temp_dir (Path): directory containing the corresponding converted .py files
- CI_output_suffix (str): suffix to append to output files capturing execution results (defaults to ".CI-output.txt")
- run_python_cmd (str | None): custom command to run Python scripts. overrides python_tool and run_python_cmd_fmt if provided (defaults to None)
- run_python_cmd_fmt (str): format string for constructing the Python run command (defaults to "{python_tool} run python")
- python_tool (str): tool used to run Python (e.g. poetry, uv) (defaults to "poetry")
- exit_on_first_fail (bool): whether to raise an exception immediately on first notebook failure (defaults to False)

Returns: None

Raises:
- NotebookTestError: if any notebooks fail to execute, or if input directories are invalid
- TypeError: if run_python_cmd is provided but not a string

>>> run_notebook_tests(
... notebooks_dir=Path("notebooks"),
... converted_notebooks_temp_dir=Path("temp/converted"),
... python_tool="poetry"
... )
### testing notebooks in 'notebooks'
### reading converted notebooks from 'temp/converted'
Running 1/2: temp/converted/notebook1.py
Output in temp/converted/notebook1.CI-output.txt
{SUCCESS_STR} Run completed with return code 0

muutils.parallel

parallel processing utilities, chiefly run_maybe_parallel
class ProgressBarFunction(typing.Protocol):

a protocol for a progress bar function

ProgressBarFunction(*args, **kwargs)

ProgressBarOption = typing.Literal['tqdm', 'spinner', 'none', None]

DEFAULT_PBAR_FN: Literal['tqdm', 'spinner', 'none', None] = 'tqdm'
def spinner_fn_wrap(x: Iterable, **kwargs) -> List

spinner wrapper

def map_kwargs_for_tqdm(kwargs: dict) -> dict

map kwargs for tqdm (we can't simply wrap tqdm, because the progress bar disappears)

def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable

fallback to no progress bar
def set_up_progress_bar_fn(
pbar: Union[muutils.parallel.ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]],
pbar_kwargs: Optional[Dict[str, Any]] = None,
**extra_kwargs
) -> Tuple[muutils.parallel.ProgressBarFunction, dict]

set up the progress bar function and its kwargs

Parameters:
- pbar (Union[ProgressBarFunction, ProgressBarOption]): progress bar function or option. if a function, we return it as-is; if a string, we figure out which progress bar to use
- pbar_kwargs (Optional[Dict[str, Any]]): kwargs passed to the progress bar function (defaults to None)

Returns: Tuple[ProgressBarFunction, dict] : a tuple of the progress bar function and its kwargs

Raises:
- ValueError : if pbar is not one of the valid options

def run_maybe_parallel(
func: Callable[[~InputType], ~OutputType],
iterable: Iterable[~InputType],
parallel: Union[bool, int],
pbar_kwargs: Optional[Dict[str, Any]] = None,
chunksize: Optional[int] = None,
keep_ordered: bool = True,
use_multiprocess: bool = False,
pbar: Union[muutils.parallel.ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]] = 'tqdm'
) -> List[~OutputType]

a function to make it easier to sometimes parallelize an operation

- if parallel is False, the function runs in serial, as map(func, iterable)
- if parallel is True, the function runs in parallel, with the maximum number of processes
- if parallel is an int, it must be greater than 1, and the function runs in parallel with that number of processes
- the maximum number of processes is min(len(iterable), multiprocessing.cpu_count())

Parameters:
- func (Callable[[InputType], OutputType]): function passed to either map or Pool.imap
- iterable (Iterable[InputType]): iterable passed to either map or Pool.imap
- parallel (bool | int): whether to run in parallel, and how many processes to use
- pbar_kwargs (Dict[str, Any]): kwargs passed to the progress bar function

Returns: List[OutputType] : a list of the output of func for each element in iterable

Raises:
- ValueError : if parallel is not a boolean or an integer greater than 1
- ValueError : if use_multiprocess=True and parallel=False
- ImportError : if use_multiprocess=True and multiprocess is not available
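A minimal sketch: the same call runs serially or with a process pool depending on the parallel flag:

```python
from muutils.parallel import run_maybe_parallel

def square(x: int) -> int:
    return x * x

results = run_maybe_parallel(square, range(10), parallel=False, pbar="none")
assert results == [x * x for x in range(10)]
# parallel=True would use up to min(len(iterable), cpu_count()) processes
```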
muutils.spinner

decorator spinner_decorator and context manager SpinnerContext to display a spinner using the base Spinner class while some code is running.

DecoratedFunction = ~DecoratedFunction

Define a generic type for the decorated function
class SpinnerConfig:

SpinnerConfig(working: List[str] = <factory>, success: str = '✔️', fail: str = '❌')

working: List[str]
success: str = '✔️'
fail: str = '❌'
- def is_ascii(self) -> bool : whether all characters are ascii
- def eq_lens(self) -> bool : whether all working characters are the same length
- def is_valid(self) -> bool : whether the spinner config is valid
def from_any(
cls,
arg: Union[str, List[str], muutils.spinner.SpinnerConfig, dict]
) -> muutils.spinner.SpinnerConfig

SpinnerConfigArg = typing.Union[str, typing.List[str], muutils.spinner.SpinnerConfig, dict]
SPINNERS: Dict[str, muutils.spinner.SpinnerConfig] = {'default': SpinnerConfig(working=['|', '/', '-', '\\'], success='#', fail='X'), 'dots': SpinnerConfig(working=['. ', '.. ', '...'], success='***', fail='xxx'), 'bars': SpinnerConfig(working=['| ', '|| ', '|||'], success='|||', fail='///'), 'arrows': SpinnerConfig(working=['<', '^', '>', 'v'], success='►', fail='✖'), 'arrows_2': SpinnerConfig(working=['←', '↖', '↑', '↗', '→', '↘', '↓', '↙'], success='→', fail='↯'), 'bouncing_bar': SpinnerConfig(working=['[ ]', '[= ]', '[== ]', '[=== ]', '[ ===]', '[ ==]', '[ =]'], success='[====]', fail='[XXXX]'), 'bar': SpinnerConfig(working=['[ ]', '[- ]', '[--]', '[ -]'], success='[==]', fail='[xx]'), 'bouncing_ball': SpinnerConfig(working=['( ● )', '( ● )', '( ● )', '( ● )', '( ●)', '( ● )', '( ● )', '( ● )', '( ● )', '(● )'], success='(●●●●●●)', fail='( ✖ )'), 'ooo': SpinnerConfig(working=['.', 'o', 'O', 'o'], success='O', fail='x'), 'braille': SpinnerConfig(working=['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'], success='⣿', fail='X'), 'clock': SpinnerConfig(working=['🕛', '🕐', '🕑', '🕒', '🕓', '🕔', '🕕', '🕖', '🕗', '🕘', '🕙', '🕚'], success='✔️', fail='❌'), 'hourglass': SpinnerConfig(working=['⏳', '⌛'], success='✔️', fail='❌'), 'square_corners': SpinnerConfig(working=['◰', '◳', '◲', '◱'], success='◼', fail='✖'), 'triangle': SpinnerConfig(working=['◢', '◣', '◤', '◥'], success='◆', fail='✖'), 'square_dot': SpinnerConfig(working=['⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽', '⣾'], success='⣿', fail='❌'), 'box_bounce': SpinnerConfig(working=['▌', '▀', '▐', '▄'], success='■', fail='✖'), 'hamburger': SpinnerConfig(working=['☱', '☲', '☴'], success='☰', fail='✖'), 'earth': SpinnerConfig(working=['🌍', '🌎', '🌏'], success='✔️', fail='❌'), 'growing_dots': SpinnerConfig(working=['⣀', '⣄', '⣤', '⣦', '⣶', '⣷', '⣿'], success='⣿', fail='✖'), 'dice': SpinnerConfig(working=['⚀', '⚁', '⚂', '⚃', '⚄', '⚅'], success='🎲', fail='✖'), 'wifi': SpinnerConfig(working=['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], success='✔️', fail='❌'), 'bounce': SpinnerConfig(working=['⠁', '⠂', '⠄', '⠂'], success='⠿', fail='⢿'), 'arc': SpinnerConfig(working=['◜', '◠', '◝', '◞', '◡', '◟'], success='○', fail='✖'), 'toggle': SpinnerConfig(working=['⊶', '⊷'], success='⊷', fail='⊗'), 'toggle2': SpinnerConfig(working=['▫', '▪'], success='▪', fail='✖'), 'toggle3': SpinnerConfig(working=['□', '■'], success='■', fail='✖'), 'toggle4': SpinnerConfig(working=['■', '□', '▪', '▫'], success='■', fail='✖'), 'toggle5': SpinnerConfig(working=['▮', '▯'], success='▮', fail='✖'), 'toggle7': SpinnerConfig(working=['⦾', '⦿'], success='⦿', fail='✖'), 'toggle8': SpinnerConfig(working=['◍', '◌'], success='◍', fail='✖'), 'toggle9': SpinnerConfig(working=['◉', '◎'], success='◉', fail='✖'), 'arrow2': SpinnerConfig(working=['⬆️ ', '↗️ ', '➡️ ', '↘️ ', '⬇️ ', '↙️ ', '⬅️ ', '↖️ '], success='➡️', fail='❌'), 'point': SpinnerConfig(working=['∙∙∙', '●∙∙', '∙●∙', '∙∙●', '∙∙∙'], success='●●●', fail='xxx'), 'layer': SpinnerConfig(working=['-', '=', '≡'], success='≡', fail='✖'), 'speaker': SpinnerConfig(working=['🔈 ', '🔉 ', '🔊 ', '🔉 '], success='🔊', fail='🔇'), 'orangePulse': SpinnerConfig(working=['🔸 ', '🔶 ', '🟠 ', '🟠 ', '🔷 '], success='🟠', fail='❌'), 'bluePulse': SpinnerConfig(working=['🔹 ', '🔷 ', '🔵 ', '🔵 ', '🔷 '], success='🔵', fail='❌'), 'satellite_signal': SpinnerConfig(working=['📡 ', '📡· ', '📡·· ', '📡···', '📡 ··', '📡 ·'], success='📡 ✔️ ', fail='📡 ❌ '), 'rocket_orbit': SpinnerConfig(working=['🌍🚀 ', '🌏 🚀 ', '🌎 🚀'], success='🌍 ✨', fail='🌍 💥'), 'ogham': SpinnerConfig(working=['ᚁ ', 'ᚂ ', 'ᚃ ', 
'ᚄ', 'ᚅ'], success='᚛᚜', fail='✖'), 'eth': SpinnerConfig(working=['᛫', '፡', '፥', '፤', '፧', '።', '፨'], success='፠', fail='✖')}
class Spinner:

displays a spinner, and optionally elapsed time and a mutable value while a function is running.

Parameters:
- update_interval (float): how often to update the spinner display in seconds (defaults to 0.1)
- initial_value (str): initial value to display with the spinner (defaults to "")
- message (str): message to display with the spinner (defaults to "")
- format_string (str): string to format the spinner with. must have "\r" prepended to clear the line. allowed keys are spinner, elapsed_time, message, and value (defaults to "\r{spinner} ({elapsed_time:.2f}s) {message}{value}")
- output_stream (TextIO): stream to write the spinner to (defaults to sys.stdout)
- format_string_when_updated (Union[bool, str]): whether to use a different format string when the value is updated. if True, use the default format string with a newline appended. if a string, use that string. this is useful if you want update_value to print to console and be preserved. (defaults to False)
- spinner_chars (Union[str, Sequence[str]]): sequence of strings, or key to look up in SPINNER_CHARS, to use as the spinner characters (defaults to "default")
- spinner_complete (str): string to display when the spinner is complete (defaults to looking up spinner_chars in SPINNER_COMPLETE or "#")

Methods:
- update_value(value: Any) -> None : update the current value displayed by the spinner

Examples:

    with SpinnerContext() as sp:
        for i in range(1):
            time.sleep(0.1)
            sp.update_value(f"Step {i+1}")

    @spinner_decorator
    def long_running_function():
        for i in range(1):
            time.sleep(0.1)
            spinner.update_value(f"Step {i+1}")
        return "Function completed"

Spinner(
*args,
config: Union[str, List[str], muutils.spinner.SpinnerConfig, dict] = 'default',
update_interval: float = 0.1,
initial_value: str = '',
message: str = '',
format_string: str = '\r{spinner} ({elapsed_time:.2f}s) {message}{value}',
output_stream: <class 'TextIO'> = <_io.TextIOWrapper encoding='UTF-8'>,
format_string_when_updated: Union[str, bool] = False,
spinner_chars: Union[str, Sequence[str], NoneType] = None,
spinner_complete: Optional[str] = None,
**kwargs: Any
)

Attributes:
- config: muutils.spinner.SpinnerConfig
- format_string_when_updated: Optional[str] : format string to use when the value is updated
- update_interval: float
- message: str
- current_value: Any
- format_string: str
- output_stream: TextIO
- start_time: float : for measuring elapsed time
- stop_spinner: threading.Event : to stop the spinner
- spinner_thread: Optional[threading.Thread] : the thread running the spinner
- value_changed: bool : whether the value has been updated since the last display
- term_width: int : width of the terminal, for padding with spaces
- state: Literal['initialized', 'running', 'success', 'fail']

Methods:
- def spin(self) -> None : function to run in a separate thread, displaying the spinner and optional information
- def update_value(self, value: Any) -> None : update the current value displayed by the spinner
- def start(self) -> None : start the spinner
- def stop(self, failed: bool = False) -> None : stop the spinner
class NoOpContextManager(typing.ContextManager):

A context manager that does nothing.

NoOpContextManager(*args, **kwargs)

class SpinnerContext(Spinner, typing.ContextManager):

displays a spinner, and optionally elapsed time and a mutable value, while some code is running. parameters and usage are the same as Spinner above.

def spinner_decorator(
*args,
config: Union[str, List[str], muutils.spinner.SpinnerConfig, dict] = 'default',
update_interval: float = 0.1,
initial_value: str = '',
message: str = '',
format_string: str = '{spinner} ({elapsed_time:.2f}s) {message}{value}',
output_stream: <class 'TextIO'> = <_io.TextIOWrapper encoding='UTF-8'>,
mutable_kwarg_key: Optional[str] = None,
spinner_chars: Union[str, Sequence[str], NoneType] = None,
spinner_complete: Optional[str] = None,
**kwargs
) -> Callable[[~DecoratedFunction], ~DecoratedFunction]

displays a spinner, and optionally elapsed time and a mutable value, while the decorated function is running. parameters are the same as Spinner above.
muutils.statcounter

StatCounter class for counting and calculating statistics on numbers; cleaner and more efficient than just using a Counter or array

NumericSequence = typing.Sequence[typing.Union[float, int, ForwardRef('NumericSequence')]]

def universal_flatten(
arr: Union[Sequence[Union[float, int, Sequence[Union[float, int, ForwardRef('NumericSequence')]]]], float, int],
require_rectangular: bool = True
) -> Sequence[Union[float, int, ForwardRef('NumericSequence')]]

flattens any iterable

class StatCounter(collections.Counter):

Counter, but with some stat calculation methods which assume the keys are numerical

works best when the keys are ints

def validate(self) -> bool

validate the counter as being all floats or ints
- def min(self) : minimum value
- def max(self) : maximum value
- def total(self) : sum of the counts
- keys_sorted: list : the keys, sorted
- def percentile(self, p: float) : return the value at the given percentile. this could be log time if we did binary search, but that would be a lot of added complexity
- def median(self) -> float
- def mean(self) -> float : return the mean of the values
- def mode(self) -> float
- def std(self) -> float : return the standard deviation of the values
def summary(
self,
typecast: Callable = <function StatCounter.<lambda>>,
*,
extra_percentiles: Optional[list[float]] = None
) -> dict[str, typing.Union[float, int]]

return a summary of the stats, without the raw data. human readable and small
def serialize(
self,
typecast: Callable = <function StatCounter.<lambda>>,
*,
extra_percentiles: Optional[list[float]] = None
) -> dict

return a json-serializable version of the counter

includes both the output of summary and the raw data:
{
"StatCounter": { <keys, values from raw data> },
"summary": self.summary(typecast, extra_percentiles=extra_percentiles),
}
def load(cls, data: dict) -> muutils.statcounter.StatCounter

load from the output of StatCounter.serialize
def from_list_arrays(
cls,
arr,
map_func: Callable = <class 'float'>
) -> muutils.statcounter.StatCounter

calls map_func on each element of
universal_flatten(arr)
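A short sketch: counts accumulate like a Counter, and the stats are computed directly from the counts:

```python
from muutils.statcounter import StatCounter

c = StatCounter([1, 2, 2, 3, 3, 3])
print(c.mean(), c.median(), c.std())
print(c.summary(extra_percentiles=[90]))
```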
muutils.sysinfo

utilities for getting information about the system, see the SysInfo class

class SysInfo:

getters for various information about the system
- def python() -> dict : details about python version
- def pip() -> dict : installed packages info
- def pytorch() -> dict : pytorch and cuda information
- def platform() -> dict
- def git_info(with_log: bool = False) -> dict

def get_all(
cls,
include: Optional[tuple[str, ...]] = None,
exclude: tuple[str, ...] = ()
) -> dict

muutils.tensor_info

get metadata about a tensor, mostly for muutils.dbg
COLORS: Dict[str, Dict[str, str]] = {'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'torch': '\\textcolor{orange}', 'dtype_bool': '\\textcolor{gray}', 'dtype_int': '\\textcolor{blue}', 'dtype_float': '\\textcolor{red!70}', 'dtype_str': '\\textcolor{red}', 'device_cuda': '\\textcolor{green}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'torch': '\x1b[38;5;208m', 'dtype_bool': '\x1b[38;5;245m', 'dtype_int': '\x1b[38;5;39m', 'dtype_float': '\x1b[38;5;167m', 'device_cuda': '\x1b[38;5;76m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'torch': '', 'dtype_bool': '', 'dtype_int': '', 'dtype_float': '', 'dtype_str': '', 'device_cuda': '', 'reset': ''}}
OutputFormat = typing.Literal['unicode', 'latex', 'ascii']
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] = {'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'distribution_log': '\\mathbb{P}_L', 'nan_values': '\\text{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla', 'true': '\\checkmark', 'false': '\\times'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'distribution_log': 'ℙ˪', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇', 'true': '✓', 'false': '✗'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'distribution_log': 'dist_log', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad', 'true': '1', 'false': '0'}}
Symbols for different formats
SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] = {'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}

characters for sparklines in different formats
def array_info(A: Any, hist_bins: int = 5) -> Dict[str, Any]

Extract statistical information from an array-like object.

- A (array-like): array to analyze (numpy array or torch tensor)

Returns: Dict[str, Any] : dictionary containing raw statistical information with numeric values

def generate_sparkline(
histogram: numpy.ndarray,
format: Literal['unicode', 'latex', 'ascii'] = 'unicode',
log_y: Optional[bool] = None
) -> tuple[str, bool]

Generate a sparkline visualization of the histogram.

Parameters:
- histogram (np.ndarray): histogram data
- format (Literal["unicode", "latex", "ascii"]): output format (defaults to "unicode")
- log_y (bool | None): whether to use logarithmic y-scale. None for automatic detection (defaults to None)

Returns: tuple[str, bool] : sparkline visualization and whether log scale was used
DEFAULT_SETTINGS: Dict[str, Any] = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': None, 'colored': False, 'as_list': False, 'eq_char': '='}
def apply_color(
text: str,
color_key: str,
colors: Dict[str, str],
using_tex: bool
) -> str

def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str

Colorize dtype string with specific colors for torch and type names.

def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str

Format shape with proper coloring for both 1D and multi-D arrays.

def format_device_colored(device_str: str, colors: Dict[str, str], using_tex: bool) -> str

Format device string with CUDA highlighting.

def array_summary(
array,
fmt: Literal['unicode', 'latex', 'ascii'] = <muutils.tensor_info._UseDefaultType object>,
precision: int = <muutils.tensor_info._UseDefaultType object>,
stats: bool = <muutils.tensor_info._UseDefaultType object>,
shape: bool = <muutils.tensor_info._UseDefaultType object>,
dtype: bool = <muutils.tensor_info._UseDefaultType object>,
device: bool = <muutils.tensor_info._UseDefaultType object>,
requires_grad: bool = <muutils.tensor_info._UseDefaultType object>,
sparkline: bool = <muutils.tensor_info._UseDefaultType object>,
sparkline_bins: int = <muutils.tensor_info._UseDefaultType object>,
sparkline_logy: Optional[bool] = <muutils.tensor_info._UseDefaultType object>,
colored: bool = <muutils.tensor_info._UseDefaultType object>,
eq_char: str = <muutils.tensor_info._UseDefaultType object>,
as_list: bool = <muutils.tensor_info._UseDefaultType object>
) -> Union[str, List[str]]

Format array information into a readable summary.

Parameters:
- array : array-like object (numpy array or torch tensor)
- precision (int): decimal places (defaults to 2)
- fmt (Literal["unicode", "latex", "ascii"]): output format (defaults to "unicode")
- stats (bool): whether to include statistical info (μ, σ, x̃) (defaults to True)
- shape (bool): whether to include shape info (defaults to True)
- dtype (bool): whether to include dtype info (defaults to True)
- device (bool): whether to include device info for torch tensors (defaults to True)
- requires_grad (bool): whether to include requires_grad info for torch tensors (defaults to True)
- sparkline (bool): whether to include a sparkline visualization (defaults to False)
- sparkline_bins (int): number of histogram bins for the sparkline (defaults to 5)
- sparkline_logy (bool | None): whether to use logarithmic y-scale for sparkline (defaults to None)
- colored (bool): whether to add color to output (defaults to False)
- as_list (bool): whether to return as list of strings instead of joined string (defaults to False)

Returns: Union[str, List[str]] : formatted statistical summary, either as string or list of strings
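An illustrative call on a small numpy array:

```python
import numpy as np
from muutils.tensor_info import array_summary

a = np.random.randn(100)
print(array_summary(a, fmt="ascii", sparkline=True))
```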
muutils.tensor_utils

utilities for working with tensors and arrays.

notably:

- TYPE_TO_JAX_DTYPE : a mapping from python, numpy, and torch types to jaxtyping types
- DTYPE_MAP : mapping string representations of types to their type
- TORCH_DTYPE_MAP : mapping string representations of types to torch types
- compare_state_dicts : for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match (see the sketch after the TYPE_TO_JAX_DTYPE entry below)
TYPE_TO_JAX_DTYPE: dict = {<class 'float'>: <class 'jaxtyping.Float'>, <class 'int'>: <class 'jaxtyping.Int'>, <class 'jaxtyping.Float'>: <class 'jaxtyping.Float'>, <class 'jaxtyping.Int'>: <class 'jaxtyping.Int'>, <class 'bool'>: <class 'jaxtyping.Bool'>, <class 'jaxtyping.Bool'>: <class 'jaxtyping.Bool'>, <class 'numpy.bool'>: <class 'jaxtyping.Bool'>, torch.bool: <class 'jaxtyping.Bool'>, <class 'numpy.float16'>: <class 'jaxtyping.Float'>, <class 'numpy.float32'>: <class 'jaxtyping.Float'>, <class 'numpy.float64'>: <class 'jaxtyping.Float'>, <class 'numpy.int8'>: <class 'jaxtyping.Int'>, <class 'numpy.int16'>: <class 'jaxtyping.Int'>, <class 'numpy.int32'>: <class 'jaxtyping.Int'>, <class 'numpy.int64'>: <class 'jaxtyping.Int'>, <class 'numpy.longlong'>: <class 'jaxtyping.Int'>, <class 'numpy.uint8'>: <class 'jaxtyping.Int'>, torch.float32: <class 'jaxtyping.Float'>, torch.float16: <class 'jaxtyping.Float'>, torch.float64: <class 'jaxtyping.Float'>, torch.bfloat16: <class 'jaxtyping.Float'>, torch.int32: <class 'jaxtyping.Int'>, torch.int8: <class 'jaxtyping.Int'>, torch.int16: <class 'jaxtyping.Int'>, torch.int64: <class 'jaxtyping.Int'>}

dict mapping python, numpy, and torch types to jaxtyping types
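The promised compare_state_dicts sketch (hedged: assumed to take the two dicts positionally; the StateDict*Error classes subclass AssertionError):

```python
import torch
from muutils.tensor_utils import compare_state_dicts

d1 = {"w": torch.zeros(2, 3)}
d2 = {"w": torch.zeros(2, 4)}
try:
    compare_state_dicts(d1, d2)
except AssertionError as e:
    print(type(e).__name__, e)  # expected: a shape mismatch error
```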
def jaxtype_factory(
name: str,
array_type: type,
default_jax_dtype=<class 'jaxtyping.Float'>,
legacy_mode: Union[muutils.errormode.ErrorMode, str] = ErrorMode.Warn
) -> type

usage:
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
ATensor = <class 'muutils.tensor_utils.jaxtype_factory.<locals>._BaseArray'>
NDArray = <class 'muutils.tensor_utils.jaxtype_factory.<locals>._BaseArray'>
def numpy_to_torch_dtype(dtype: Union[numpy.dtype, torch.dtype]) -> torch.dtype

convert numpy dtype to torch dtype
DTYPE_LIST: list = [<class 'bool'>, <class 'int'>, <class 'float'>, torch.float32, torch.float32, torch.float64, torch.float16, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.int32, torch.int8, torch.int16, torch.int32, torch.int64, torch.int64, torch.int16, torch.uint8, torch.bool, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.int8'>, <class 'numpy.int16'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.longlong'>, <class 'numpy.int16'>, <class 'numpy.uint8'>, <class 'numpy.bool'>]list of all the python, numpy, and torch numerical types I could think of
DTYPE_MAP: dict = {"<class 'bool'>": <class 'bool'>, "<class 'int'>": <class 'int'>, "<class 'float'>": <class 'float'>, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": <class 'numpy.float16'>, "<class 'numpy.float32'>": <class 'numpy.float32'>, "<class 'numpy.float64'>": <class 'numpy.float64'>, "<class 'numpy.complex64'>": <class 'numpy.complex64'>, "<class 'numpy.complex128'>": <class 'numpy.complex128'>, "<class 'numpy.int8'>": <class 'numpy.int8'>, "<class 'numpy.int16'>": <class 'numpy.int16'>, "<class 'numpy.int32'>": <class 'numpy.int32'>, "<class 'numpy.int64'>": <class 'numpy.int64'>, "<class 'numpy.longlong'>": <class 'numpy.longlong'>, "<class 'numpy.uint8'>": <class 'numpy.uint8'>, "<class 'numpy.bool'>": <class 'numpy.bool'>, 'float16': <class 'numpy.float16'>, 'float32': <class 'numpy.float32'>, 'float64': <class 'numpy.float64'>, 'complex64': <class 'numpy.complex64'>, 'complex128': <class 'numpy.complex128'>, 'int8': <class 'numpy.int8'>, 'int16': <class 'numpy.int16'>, 'int32': <class 'numpy.int32'>, 'int64': <class 'numpy.int64'>, 'longlong': <class 'numpy.longlong'>, 'uint8': <class 'numpy.uint8'>, 'bool': <class 'numpy.bool'>}mapping from string representations of types to their type
```python
TORCH_DTYPE_MAP: dict = {
    # same keys as DTYPE_MAP, but every value is the nearest torch dtype:
    "<class 'bool'>": torch.bool,
    "<class 'int'>": torch.int64,
    "<class 'float'>": torch.float64,
    "<class 'numpy.float32'>": torch.float32,
    'float32': torch.float32,
    'longlong': torch.int64,
    # (and so on -- numpy.longlong maps to torch.int64, numpy.bool to torch.bool, etc.)
}
```

mapping from string representations of types to specifically torch types
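these maps are handy when dtypes arrive as strings, e.g. from a JSON config file (a minimal sketch using entries shown above):

```python
>>> from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP
>>> DTYPE_MAP["float32"]
<class 'numpy.float32'>
>>> TORCH_DTYPE_MAP["float32"]
torch.float32
>>> TORCH_DTYPE_MAP["<class 'numpy.longlong'>"]
torch.int64
```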
```python
TORCH_OPTIMIZERS_MAP: dict[str, Type[torch.optim.Optimizer]] = {
    'Adagrad': torch.optim.Adagrad,
    'Adam': torch.optim.Adam,
    'AdamW': torch.optim.AdamW,
    'SparseAdam': torch.optim.SparseAdam,
    'Adamax': torch.optim.Adamax,
    'ASGD': torch.optim.ASGD,
    'LBFGS': torch.optim.LBFGS,
    'NAdam': torch.optim.NAdam,
    'RAdam': torch.optim.RAdam,
    'RMSprop': torch.optim.RMSprop,
    'Rprop': torch.optim.Rprop,
    'SGD': torch.optim.SGD,
}
```
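a typical use is resolving an optimizer by name from a config (a minimal sketch; the model and learning rate are placeholders):

```python
import torch
from muutils.tensor_utils import TORCH_OPTIMIZERS_MAP

model = torch.nn.Linear(10, 2)  # placeholder model
optim_cls = TORCH_OPTIMIZERS_MAP["AdamW"]
optimizer = optim_cls(model.parameters(), lr=1e-3)
```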
```python
def pad_tensor(
    tensor: jaxtyping.Shaped[Tensor, 'dim1'],
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False
) -> jaxtyping.Shaped[Tensor, 'padded_length']
```

pad a 1-d tensor on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
```python
def lpad_tensor(
    tensor: torch.Tensor,
    padded_length: int,
    pad_value: float = 0.0
) -> torch.Tensor
```

pad a 1-d tensor on the left with pad_value to length padded_length

```python
def rpad_tensor(
    tensor: torch.Tensor,
    pad_length: int,
    pad_value: float = 0.0
) -> torch.Tensor
```

pad a 1-d tensor on the right with pad_value to length pad_length

```python
def pad_array(
    array: jaxtyping.Shaped[ndarray, 'dim1'],
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False
) -> jaxtyping.Shaped[ndarray, 'padded_length']
```

pad a 1-d array on the left with pad_value to length padded_length

set rpad = True to pad on the right instead

```python
def lpad_array(
    array: numpy.ndarray,
    padded_length: int,
    pad_value: float = 0.0
) -> numpy.ndarray
```

pad a 1-d array on the left with pad_value to length padded_length

```python
def rpad_array(
    array: numpy.ndarray,
    pad_length: int,
    pad_value: float = 0.0
) -> numpy.ndarray
```

pad a 1-d array on the right with pad_value to length pad_length
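a quick sketch of the padding helpers (expected outputs shown as comments, per the left-pad default documented above):

```python
import numpy as np
import torch
from muutils.tensor_utils import pad_tensor, rpad_tensor, pad_array

t = torch.tensor([1.0, 2.0, 3.0])
pad_tensor(t, 5)                # tensor([0., 0., 1., 2., 3.]) -- pads on the left by default
pad_tensor(t, 5, rpad=True)     # tensor([1., 2., 3., 0., 0.])
rpad_tensor(t, 5, pad_value=9)  # tensor([1., 2., 3., 9., 9.])

a = np.array([1.0, 2.0])
pad_array(a, 4)                 # array([0., 0., 1., 2.])
```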
```python
def get_dict_shapes(d: dict[str, torch.Tensor]) -> dict[str, tuple[int, ...]]
```

given a state dict or cache dict, compute the shapes and put them in a nested dict

```python
def string_dict_shapes(d: dict[str, torch.Tensor]) -> str
```

printable version of get_dict_shapes
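for example (a sketch; per the docstring, dotlist keys are split into a nested dict):

```python
>>> import torch
>>> from muutils.tensor_utils import get_dict_shapes
>>> d = {"embed.W_E": torch.zeros(50257, 768), "embed.W_pos": torch.zeros(1024, 768)}
>>> get_dict_shapes(d)
{'embed': {'W_E': (50257, 768), 'W_pos': (1024, 768)}}
```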
- `StateDictCompareError(AssertionError)`: raised when state dicts don’t match
- `StateDictKeysError(StateDictCompareError)`: raised when state dict keys don’t match
- `StateDictShapeError(StateDictCompareError)`: raised when state dict shapes don’t match
- `StateDictValueError(StateDictCompareError)`: raised when state dict values don’t match
```python
def compare_state_dicts(
    d1: dict,
    d2: dict,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    verbose: bool = True
) -> None
```

compare two dicts of tensors

Parameters:

- `d1 : dict`
- `d2 : dict`
- `rtol : float` (defaults to `1e-5`)
- `atol : float` (defaults to `1e-8`)
- `verbose : bool` (defaults to `True`)

Raises:

- `StateDictKeysError`: keys don’t match
- `StateDictShapeError`: shapes don’t match (but keys do)
- `StateDictValueError`: values don’t match (but keys and shapes do)
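a minimal sketch of the comparison and its exceptions:

```python
import torch
from muutils.tensor_utils import compare_state_dicts, StateDictValueError

d1 = {"w": torch.ones(3)}
d2 = {"w": torch.ones(3) * 2}  # same keys and shapes, different values

try:
    compare_state_dicts(d1, d2, verbose=False)
except StateDictValueError as e:
    print("values differ:", e)
```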
muutils.timeit_fancy

timeit_fancy is just a fancier version of timeit with more options
```python
class FancyTimeitResult(typing.NamedTuple):
```

return type of timeit_fancy

```python
FancyTimeitResult(
    timings: StatCounter,
    return_value: T_return,
    profile: Union[pstats.Stats, None]
)
```

- `timings: muutils.statcounter.StatCounter` (field 0)
- `return_value: ~T_return` (field 1)
- `profile: Optional[pstats.Stats]` (field 2)
```python
def timeit_fancy(
    cmd: Union[Callable[[], T_return], str],
    setup: Union[str, Callable[[], Any]] = <function <lambda>>,
    repeats: int = 5,
    namespace: Optional[dict[str, Any]] = None,
    get_return: bool = True,
    do_profiling: bool = False
) -> muutils.timeit_fancy.FancyTimeitResult
```

Wrapper for timeit to get the fastest run of a callable, with more customization options.

Approximates the functionality of the %timeit magic or command line interface in a Python callable.

Parameters:

- `cmd: Callable[[], T_return] | str`: the callable to time. If a string, it will be passed to `timeit.Timer` as the `stmt` argument.
- `setup: str`: the setup code to run before `cmd`. If a string, it will be passed to `timeit.Timer` as the `setup` argument.
- `repeats: int`: the number of times to run `cmd` to get a reliable measurement.
- `namespace: dict[str, Any]`: passed to the `timeit.Timer` constructor. If `cmd` or `setup` use local or global variables, they must be passed here; see the `timeit` documentation for details.
- `get_return: bool`: whether to capture the value returned from `cmd`. If `True`, the return value is included in the result so that `cmd` doesn’t need to be run again in the calling scope if its return value is needed. (default: `True`)
- `do_profiling: bool`: whether to also collect a `pstats.Stats` profile of the run. (default: `False`)

Returns a `FancyTimeitResult`, which is a NamedTuple with the following fields:

- `timings: StatCounter`: the measured run times in seconds, one per repeat.
- `return_value: T_return | None`: the return value of `cmd` if `get_return` is `True`, otherwise `None`.
- `profile: pstats.Stats | None`: a `pstats.Stats` object if `do_profiling` is `True`, otherwise `None`.
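a minimal sketch of timing a small callable (the `.mean()` accessor comes from StatCounter; see the statcounter docs):

```python
from muutils.timeit_fancy import timeit_fancy

result = timeit_fancy(lambda: sum(range(10_000)), repeats=10)

print(result.timings.mean())  # mean runtime in seconds across the 10 repeats
print(result.return_value)    # 49995000, since get_return defaults to True
```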
muutils.validate_type

experimental utility for validating types in python, see validate_type
```python
GenericAliasTypes: tuple = (types.GenericAlias, typing._GenericAlias, typing._UnionGenericAlias, typing._BaseGenericAlias)
```

- `IncorrectTypeException(TypeError)`: raised when the type is incorrect and `do_except` is `True`
- `TypeHintNotImplementedError(NotImplementedError)`: raised when a type hint is not implemented
- `InvalidGenericAliasError(TypeError)`: raised when a generic alias is invalid
```python
def validate_type(value: Any, expected_type: Any, do_except: bool = False) -> bool
```

Validate that a value is of the expected_type

Parameters:

- `value`: the value to check the type of
- `expected_type`: the type to check against. Not all types are supported
- `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`) (default: `False`)

Returns:

- `bool`: `True` if the value is of the expected type, `False` otherwise

Raises:

- `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True`
- `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented
- `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid

use typeguard for a more robust solution: https://github.com/agronholm/typeguard
```python
def get_fn_allowed_kwargs(fn: Callable) -> Set[str]
```

Get the allowed kwargs for a function, raising an exception if the signature cannot be determined.
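a minimal sketch of both helpers (expected results shown as comments):

```python
from muutils.validate_type import validate_type, get_fn_allowed_kwargs

validate_type([1, 2, 3], list[int])      # True
validate_type({"a": 1}, dict[str, str])  # False -- the values are int, not str

def f(x: int, y: str = "hi") -> None: ...

get_fn_allowed_kwargs(f)  # {'x', 'y'}
```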
muutils.web
muutils.web.bundle_html

Inline / bundle external assets (CSS, JS, SVG, PNG) into an HTML document.

Default mode uses zero external dependencies and a few well-targeted regular expressions. If you install beautifulsoup4 you can enable the far more robust BS4 mode by passing InlineConfig(use_bs4=True).
```python
AssetExt = typing.Literal['.css', '.js', '.svg', '.png']
DEFAULT_ALLOWED_EXTENSIONS: Final[set[AssetExt]] = {'.css', '.js', '.svg', '.png'}
DEFAULT_TAG_ATTR: Final[dict[str, str]] = {'link': 'href', 'script': 'src', 'img': 'src', 'use': 'xlink:href'}
MIME_BY_EXT: Final[dict[AssetExt, str]] = {'.css': 'text/css', '.js': 'application/javascript', '.svg': 'image/svg+xml', '.png': 'image/png'}
```
```python
class InlineConfig:
```

High-level configuration for the inliner.

- `allowed_extensions : set[AssetExt]`: extensions that may be inlined
- `tag_attr : dict[str, str]`: mapping tag -> attribute that holds the asset reference
- `max_bytes : int`: assets larger than this are ignored
- `local : bool`: allow local filesystem assets
- `remote : bool`: allow remote http/https assets
- `include_filename_comments : bool`: surround every replacement with `<!-- begin '...' -->` and `<!-- end '...' -->`
- `use_bs4 : bool`: parse the document with BeautifulSoup if available

```python
InlineConfig(
    allowed_extensions: set[AssetExt] = <factory>,
    tag_attr: dict[str, str] = <factory>,
    max_bytes: int = 131072,
    local: bool = True,
    remote: bool = False,
    include_filename_comments: bool = True,
    use_bs4: bool = False
)
```
```python
def inline_html_assets(
    html: str,
    *,
    base_path: Path,
    config: InlineConfig | None = None,
    prettify: bool = False
) -> str
```

Inline permitted external assets inside html.

Parameters:

- `html : str`: raw HTML text
- `base_path : Path`: directory used to resolve relative asset paths
- `config : InlineConfig | None`: inlining options (see InlineConfig)
- `prettify : bool`: pretty-print output (only effective in BS4 mode)

Returns:

- `str`: modified HTML

```python
def inline_html_file(
    html_path: Path,
    output_path: Path,
    base_path: Path | None = None,
    config: InlineConfig | None = None,
    prettify: bool = False
) -> Path
```

Read html_path, inline its assets, and write the result.

Parameters:

- `html_path : Path`: source HTML file
- `output_path : Path`: destination path to write the modified HTML
- `base_path : Path | None`: directory used to resolve relative asset paths (default: `None` -> use `html_path.parent`)
- `config : InlineConfig | None`: inlining options (default: `None` -> use `InlineConfig()`)
- `prettify : bool`: pretty-print when `use_bs4=True` (default: `False`)

Returns:

- `Path`: path actually written
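a minimal sketch of bundling a page (the file paths here are hypothetical):

```python
from pathlib import Path

from muutils.web.bundle_html import InlineConfig, inline_html_file

inline_html_file(
    html_path=Path("site/index.html"),  # hypothetical source file
    output_path=Path("site/index.bundled.html"),
    config=InlineConfig(
        allowed_extensions={".css", ".js"},  # skip inlining images
        include_filename_comments=True,
        use_bs4=False,  # regex mode: no extra dependencies
    ),
)
```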