docs for muutils v0.8.7
muutils, stylized as “μutils” or “muutils”, is a collection of miscellaneous python utilities, meant to be small and with no dependencies outside of standard python.
PyPi: muutils
pip install muutils
Note that for using mlutils, tensor_utils, nbutils.configure_notebook, or the array serialization features of json_serialize, you will need to install with optional array dependencies:
pip install muutils[array]
hosted html docs: https://miv.name/muutils
statcounter: an extension of collections.Counter that provides “smart” computation of stats (mean, variance, median, other percentiles) from the counter object without using Counter.elements()
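A minimal sketch of the intended usage (the stat-method names mean and median here are assumptions based on the description above, not confirmed signatures):

```python
from muutils.statcounter import StatCounter

# counts occurrences like a Counter, but computes stats directly from the
# counts, without materializing Counter.elements()
c = StatCounter([1, 1, 2, 3, 3, 3])
print(c.mean())    # ~2.17
print(c.median())  # 2.5 with the usual even-length convention
```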
dictmagic: has utilities for working with dictionaries, like:

```python
>>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
>>> nested_dict_to_dotlist({'a': {'b': {'c': 1, 'd': 2}, 'e': 3}})
{'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}
```
DefaulterDict: works like a defaultdict, but can generate the default value based on the key

condense_tensor_dict: takes a dict of dotlist-tensors and gives a more human-readable summary:

```python
>>> model = MyGPT()
>>> print(condense_tensor_dict(model.named_parameters(), 'yaml'))
```

```yaml
embed:
  W_E: (50257, 768)
pos_embed:
  W_pos: (1024, 768)
blocks:
  '[0-11]':
    attn:
      '[W_Q, W_K, W_V]': (12, 768, 64)
      W_O: (12, 64, 768)
      '[b_Q, b_K, b_V]': (12, 64)
      b_O: (768,)
<...>
```
kappa: anonymous getitem, so you can do things like

```python
>>> k = Kappa(lambda x: x**2)
>>> k[2]
4
```
sysinfo: utility for getting a bunch of system information. useful for logging.
misc: contains a few utilities (a short usage sketch follows this list):

- stable_hash() uses hashlib.sha256 to compute a hash of an object that is stable across runs of python
- list_join and list_split, which behave like str.join and str.split but for lists
- sanitize_fname and dict_to_filename for simplifying the creation of unique filenames
- shorten_numerical_to_str() and str_to_numeric turn numbers like 123456789 into "123M" and back
- freeze, which prevents an object from being modified. Also see gelidum
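A brief sketch of a few of these (the printed outputs are illustrative, not verified):

```python
from muutils.misc import stable_hash, list_split, shorten_numerical_to_str

print(stable_hash("hello"))                 # same value on every python run
print(list_split([1, 2, 0, 3, 0, 4], 0))    # [[1, 2], [3], [4]]
print(shorten_numerical_to_str(123456789))  # "123M"
```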
nbutils: contains utilities for working with jupyter notebooks, such as nbutils.configure_notebook
json_serialize: a tool for serializing and loading arbitrary python objects into json. plays nicely with ZANJ
tensor_utils: contains minor utilities for working with pytorch tensors and numpy arrays, mostly for making type conversions easier
group_equiv: groups elements from a sequence according to a given equivalence relation, without assuming that the equivalence relation obeys the transitive property
jsonlines: an extremely simple utility for reading/writing jsonl files
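A minimal sketch (the function names jsonl_write and jsonl_load are assumptions based on the module's description):

```python
from muutils.jsonlines import jsonl_write, jsonl_load  # assumed names

# write one JSON object per line, then read them back
jsonl_write("records.jsonl", [{"a": 1}, {"a": 2}])
print(jsonl_load("records.jsonl"))  # [{'a': 1}, {'a': 2}]
```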
ZANJ is a human-readable and simple format for ML models, datasets, and arbitrary objects. It’s built around having a zip file with json and npy files, and has been spun off into its own project.
There are a couple work-in-progress utilities in _wip
that aren’t ready for anything, but nothing in this repo is suitable for
production. Use at your own risk!
submodules:

- json_serialize
- logger
- misc
- nbutils
- console_unicode
- dbg
- dictmagic
- errormode
- group_equiv
- interval
- jsonlines
- kappa
- mlutils
- parallel
- spinner
- statcounter
- sysinfo
- tensor_info
- tensor_utils
- timeit_fancy
- validate_type
muutils.console_unicode
def get_console_safe_str(default: str, fallback: str) -> str
Determine a console-safe string based on the preferred encoding.
This function attempts to encode a given default string using the system’s preferred encoding. If encoding is successful, it returns the default string; otherwise, it returns the fallback string.
Parameters:

- default : str
  the primary string intended for use, to be tested against the system’s preferred encoding
- fallback : str
  the alternative string to be used if default cannot be encoded in the system’s preferred encoding

Returns:

- str
  either default or fallback, based on whether default can be encoded safely

>>> get_console_safe_str("café", "cafe")
"café"  # This result may vary based on the system's preferred encoding.
muutils.dbg
this code is based on an implementation of the Rust builtin dbg! for Python, originally from https://github.com/tylerwince/pydbg/blob/master/pydbg.py although it has been significantly modified
licensed under MIT:
Copyright (c) 2019 Tyler Wince
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
PATH_MODE: Literal['relative', 'absolute'] = 'relative'
DEFAULT_VAL_JOINER: str = ' = '
def dbg(
    exp: Union[~_ExpType, muutils.dbg._NoExpPassedSentinel] = <muutils.dbg._NoExpPassedSentinel object>,
    formatter: Optional[Callable[[Any], str]] = None,
    val_joiner: str = ' = '
) -> Union[~_ExpType, muutils.dbg._NoExpPassedSentinel]
Call dbg with any variable or expression.
Calling dbg will print to stderr the current filename and lineno, as well as the passed expression and what the expression evaluates to:
```python
from muutils.dbg import dbg

a = 2
b = 5
dbg(a + b)

def square(x: int) -> int:
    return x * x

dbg(square(a))
```
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: Dict[str, Union[str, int, bool]] = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': True, 'sparkline_bins': 7, 'sparkline_logy': False, 'colored': True, 'eq_char': '='}
DBG_TENSOR_VAL_JOINER: str = ': '
def tensor_info(tensor: Any) -> str
dbg_tensor = functools.partial(<function dbg>, formatter=<function tensor_info>, val_joiner=': ')
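Since dbg_tensor is just dbg preconfigured with tensor_info as the formatter, it prints a compact summary (shape, dtype, stats, sparkline, per the defaults above) instead of the raw values. A minimal sketch, assuming numpy is installed:

```python
import numpy as np
from muutils.dbg import dbg_tensor

x = np.random.randn(128, 64)
dbg_tensor(x)  # prints file:line, the expression, and a tensor summary to stderr
```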
muutils.dictmagic
making working with dictionaries easier

- DefaulterDict: like a defaultdict, but default_factory is passed the key as an argument
- condense_nested_dicts: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
- condense_tensor_dict: convert a dictionary of tensors to a dictionary of shapes
- kwargs_to_nested_dict: given kwargs from fire, convert them to a nested dict

class DefaulterDict(typing.Dict[~_KT, ~_VT], typing.Generic[~_KT, ~_VT]):
like a defaultdict, but default_factory is passed the key as an argument
default_factory: Callable[[~_KT], ~_VT]
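A minimal sketch (assuming default_factory can be passed to the constructor as a keyword argument):

```python
from muutils.dictmagic import DefaulterDict

# unlike defaultdict, the factory receives the missing key itself
d: DefaulterDict[str, str] = DefaulterDict(default_factory=lambda k: k.upper())
print(d["hello"])  # 'HELLO'
```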
def defaultdict_to_dict_recursive(dd: Union[collections.defaultdict, muutils.dictmagic.DefaulterDict]) -> dict
Convert a defaultdict or DefaulterDict to a normal dict, recursively
def dotlist_to_nested_dict(dot_dict: Dict[str, Any], sep: str = '.') -> Dict[str, Any]
Convert a dict with dot-separated keys to a nested dict
Example:
>>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
def nested_dict_to_dotlist(
    nested_dict: Dict[str, Any],
    sep: str = '.',
    allow_lists: bool = False
) -> dict[str, typing.Any]
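Convert a nested dict to a dict with dot-separated keys; the inverse of dotlist_to_nested_dict (this example is taken from the package overview above):

>>> nested_dict_to_dotlist({'a': {'b': {'c': 1, 'd': 2}, 'e': 3}})
{'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}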
def update_with_nested_dict(
    original: dict[str, typing.Any],
    update: dict[str, typing.Any]
) -> dict[str, typing.Any]
Update a dict with a nested dict
Example:

>>> update_with_nested_dict({'a': {'b': 1}, 'c': -1}, {'a': {'b': 2}})
{'a': {'b': 2}, 'c': -1}
Parameters:

- original: dict[str, Any]
  the dict to update (will be modified in-place)
- update: dict[str, Any]
  the dict to update with

Returns:

- dict
  the updated dict

def kwargs_to_nested_dict(
    kwargs_dict: dict[str, typing.Any],
    sep: str = '.',
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: Union[muutils.errormode.ErrorMode, str] = ErrorMode.Warn,
    transform_key: Optional[Callable[[str], str]] = None
) -> dict[str, typing.Any]
given kwargs from fire, convert them to a nested dict
if strip_prefix is not None, then all keys must start with the prefix. by default, will warn if an unknown prefix is found, but can be set to raise an error or ignore it via when_unknown_prefix: ErrorMode
Example:

```python
def main(**kwargs):
    print(kwargs_to_nested_dict(kwargs))
fire.Fire(main)
```

running the above script will give:

```
$ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
```
Parameters:

- kwargs_dict: dict[str, Any]
  the kwargs dict to convert
- sep: str = "."
  the separator to use for nested keys
- strip_prefix: Optional[str] = None
  if not None, then all keys must start with this prefix
- when_unknown_prefix: ErrorMode = ErrorMode.WARN
  what to do when an unknown prefix is found
- transform_key: Callable[[str], str] | None = None
  a function to apply to each key before adding it to the dict (applied after stripping the prefix)

def is_numeric_consecutive(lst: list[str]) -> bool
Check if the list of keys is numeric and consecutive.
def condense_nested_dicts_numeric_keys(data: dict[str, typing.Any]) -> dict[str, typing.Any]

condense a nested dict, by condensing numeric keys with matching values to ranges

>>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
{'[1-3]': 1, '[4-6]': 2}
>>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
{"1": {"[1-2]": "a"}, "2": "b"}
def condense_nested_dicts_matching_values(
    data: dict[str, typing.Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None
) -> dict[str, typing.Any]
condense a nested dict, by condensing keys with matching values
Parameters:

- data : dict[str, Any]
  data to process
- val_condense_fallback_mapping : Callable[[Any], Hashable] | None
  a function to apply to each value before adding it to the dict (if it’s not hashable) (defaults to None)

def condense_nested_dicts(
    data: dict[str, typing.Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None
) -> dict[str, typing.Any]
condense a nested dict, by condensing numeric or matching keys with matching values to ranges
combines the functionality of condense_nested_dicts_numeric_keys() and condense_nested_dicts_matching_values(). it’s not reversible because types are lost to make the printing pretty
Parameters:

- data : dict[str, Any]
  data to process
- condense_numeric_keys : bool
  whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") (defaults to True)
- condense_matching_values : bool
  whether to condense keys with matching values (defaults to True)
- val_condense_fallback_mapping : Callable[[Any], Hashable] | None
  a function to apply to each value before adding it to the dict (if it’s not hashable) (defaults to None)

def tuple_dims_replace(
    t: tuple[int, ...],
    dims_names_map: Optional[dict[int, str]] = None
) -> tuple[typing.Union[int, str], ...]
TensorDict = typing.Dict[str, ForwardRef('torch.Tensor|np.ndarray')]
TensorIterable = typing.Iterable[typing.Tuple[str, ForwardRef('torch.Tensor|np.ndarray')]]
TensorDictFormats = typing.Literal['dict', 'json', 'yaml', 'yml']
def condense_tensor_dict(
    data: 'TensorDict | TensorIterable',
    fmt: Literal['dict', 'json', 'yaml', 'yml'] = 'dict',
    *args,
    shapes_convert: Callable[[tuple], Any] = <function _default_shapes_convert>,
    drop_batch_dims: int = 0,
    sep: str = '.',
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[Literal['dict', 'json', 'yaml', 'yml']] = None
) -> Union[str, dict[str, str | tuple[int, ...]]]
Convert a dictionary of tensors to a dictionary of shapes.
by default, values are converted to strings of their shapes (for nice printing). If you want the actual shapes, set shapes_convert = lambda x: x or shapes_convert = None.
Parameters:

- data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]
  either a TensorDict (a dict from strings to tensors) or a TensorIterable (an iterable of (key, tensor) pairs, like you might get from a dict().items())
- fmt : TensorDictFormats
  format to return the result in – either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. (defaults to 'dict')
- shapes_convert : Callable[[tuple], Any]
  conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes: lambda x: str(x).replace('"', '').replace("'", ''))
- drop_batch_dims : int
  number of leading dimensions to drop from the shape (defaults to 0)
- sep : str
  separator to use for nested keys (defaults to '.')
- dims_names_map : dict[int, str] | None
  convert certain dimension values in shape. not perfect, can be buggy (defaults to None)
- condense_numeric_keys : bool
  whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to condense_nested_dicts (defaults to True)
- condense_matching_values : bool
  whether to condense keys with matching values, passed on to condense_nested_dicts (defaults to True)
- val_condense_fallback_mapping : Callable[[Any], Hashable] | None
  a function to apply to each value before adding it to the dict (if it’s not hashable), passed on to condense_nested_dicts (defaults to None)
- return_format : TensorDictFormats | None
  legacy alias for the fmt kwarg

Returns:

- str | dict[str, str | tuple[int, ...]]
  a dict if return_format='dict', a string for json or yaml output

Example:

>>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
>>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml'))
embed:
W_E: (50257, 768)
pos_embed:
W_pos: (1024, 768)
blocks:
'[0-11]':
attn:
'[W_Q, W_K, W_V]': (12, 768, 64)
W_O: (12, 64, 768)
'[b_Q, b_K, b_V]': (12, 64)
b_O: (768,)
mlp:
W_in: (768, 3072)
b_in: (3072,)
W_out: (3072, 768)
b_out: (768,)
unembed:
W_U: (768, 50257)
b_U: (50257,)
Raises:

- ValueError : if return_format is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
muutils.errormode
provides ErrorMode enum for handling errors consistently

pass an error_mode: ErrorMode to a function to specify how to handle a certain kind of exception. That function then, instead of raise-ing or warnings.warn-ing, calls error_mode.process with the message and the exception.

you can also specify the exception class to raise, the warning class to use, and the source of the exception/warning.
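A minimal sketch of the intended pattern (the load_config function here is hypothetical; ErrorMode, from_any, and process are from this module):

```python
from muutils.errormode import ErrorMode

def load_config(path: str, error_mode: ErrorMode = ErrorMode.WARN) -> dict:
    try:
        raise FileNotFoundError(path)  # stand-in for real parsing logic
    except FileNotFoundError as e:
        # warn, raise, log, or ignore, depending on what the caller chose
        error_mode.process(
            f"could not load config from {path}",
            except_cls=FileNotFoundError,
            except_from=e,
        )
    return {}

load_config("missing.toml", error_mode=ErrorMode.from_any("ignore"))
```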
class WarningFunc(typing.Protocol):

Base class for protocol classes.

Protocol classes are defined as:

    class Proto(Protocol):
        def meth(self) -> int:
            ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example:

    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as:

    class GenProto[T](Protocol):
        def meth(self) -> T:
            ...
WarningFunc(*args, **kwargs)
LoggingFunc = typing.Callable[[str], NoneType]
def GLOBAL_WARN_FUNC(unknown)

Issue a warning, or maybe ignore it or raise an exception.

- message: Text of the warning message.
- category: The Warning category subclass. Defaults to UserWarning.
- stacklevel: How far up the call stack to make this warning appear. A value of 2 for example attributes the warning to the caller of the code calling warn().
- source: If supplied, the destroyed object which emitted a ResourceWarning.
- skip_file_prefixes: An optional tuple of module filename prefixes indicating frames to skip during stacklevel computations for stack frame attribution.
def GLOBAL_LOG_FUNC(*args, sep=' ', end='\n', file=None, flush=False)

Prints the values to a stream, or to sys.stdout by default.

- sep: string inserted between values, default a space.
- end: string appended after the last value, default a newline.
- file: a file-like object (stream); defaults to the current sys.stdout.
- flush: whether to forcibly flush the stream.
def custom_showwarning(
    message: Warning | str,
    category: Optional[Type[Warning]] = None,
    filename: str | None = None,
    lineno: int | None = None,
    file: Optional[TextIO] = None,
    line: Optional[str] = None
) -> None
class ErrorMode(enum.Enum):
Enum for handling errors consistently
pass one of the instances of this enum to a function to specify how to handle a certain kind of exception.
That function then, instead of raise-ing or warnings.warn-ing, calls error_mode.process with the message and the exception.
EXCEPT = ErrorMode.Except
WARN = ErrorMode.Warn
LOG = ErrorMode.Log
IGNORE = ErrorMode.Ignore
def process(
    self,
    msg: str,
    except_cls: Type[Exception] = ValueError,
    warn_cls: Type[Warning] = UserWarning,
    except_from: Optional[Exception] = None,
    warn_func: muutils.errormode.WarningFunc | None = None,
    log_func: Optional[Callable[[str], NoneType]] = None
)
process an exception or warning according to the error mode
Parameters:

- msg : str
  message to pass to except_cls or warn_func
- except_cls : typing.Type[Exception]
  exception class to raise, must be a subclass of Exception (defaults to ValueError)
- warn_cls : typing.Type[Warning]
  warning class to use, must be a subclass of Warning (defaults to UserWarning)
- except_from : typing.Optional[Exception]
  will raise except_cls(msg) from except_from if not None (defaults to None)
- warn_func : WarningFunc | None
  function to use for warnings, must have the signature warn_func(msg: str, category: typing.Type[Warning], source: typing.Any = None) -> None (defaults to None)
- log_func : LoggingFunc | None
  function to use for logging, must have the signature log_func(msg: str) -> None (defaults to None)

Raises:

- except_cls : description
- ValueError : description

def from_any(
    cls,
    mode: str | muutils.errormode.ErrorMode,
    allow_aliases: bool = True,
    allow_prefix: bool = True
) -> muutils.errormode.ErrorMode
initialize an ErrorMode from a string or an ErrorMode instance
def serialize(self) -> str

def load(cls, data: str) -> muutils.errormode.ErrorMode
ERROR_MODE_ALIASES: dict[str, muutils.errormode.ErrorMode] = {'except': ErrorMode.Except, 'warn': ErrorMode.Warn, 'log': ErrorMode.Log, 'ignore': ErrorMode.Ignore, 'e': ErrorMode.Except, 'error': ErrorMode.Except, 'err': ErrorMode.Except, 'raise': ErrorMode.Except, 'w': ErrorMode.Warn, 'warning': ErrorMode.Warn, 'l': ErrorMode.Log, 'print': ErrorMode.Log, 'output': ErrorMode.Log, 'show': ErrorMode.Log, 'display': ErrorMode.Log, 'i': ErrorMode.Ignore, 'silent': ErrorMode.Ignore, 'quiet': ErrorMode.Ignore, 'nothing': ErrorMode.Ignore}
map of string aliases to ErrorMode
instances
muutils.group_equiv
group items by assuming that eq_func defines an equivalence relation
def group_by_equivalence(
    items_in: Sequence[~T],
    eq_func: Callable[[~T, ~T], bool]
) -> list[list[~T]]
group items by assuming that eq_func implies an equivalence relation but might not be transitive
so, if f(a,b) and f(b,c) then f(a,c) might be false, but we still want to put [a,b,c] in the same class
note that lists are used to avoid the need for hashable items, and to allow for duplicates
Parameters:

- items_in: Sequence[T]
  the items to group
- eq_func: Callable[[T, T], bool]
  a function that returns true if two items are equivalent. need not be transitive
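A short sketch: "within 1 of each other" is not transitive, but chained-close items still land in one group (the exact output ordering is not guaranteed):

```python
from muutils.group_equiv import group_by_equivalence

items = [1, 2, 3, 10, 11, 20]
groups = group_by_equivalence(items, lambda a, b: abs(a - b) <= 1)
print(groups)  # e.g. [[1, 2, 3], [10, 11], [20]]
```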
muutils.interval
represents a mathematical Interval over the real numbers
Number = typing.Union[float, int]
class Interval:
Represents a mathematical interval, open by default.
The Interval class can represent both open and closed intervals, as well as half-open intervals. It supports various initialization methods and provides containment checks.
Examples:
>>> i1 = Interval(1, 5) # Default open interval (1, 5)
>>> 3 in i1
True
>>> 1 in i1
False
>>> i2 = Interval([1, 5]) # Closed interval [1, 5]
>>> 1 in i2
True
>>> i3 = Interval(1, 5, closed_L=True) # Half-open interval [1, 5)
>>> str(i3)
'[1, 5)'
>>> i4 = ClosedInterval(1, 5) # Closed interval [1, 5]
>>> i5 = OpenInterval(1, 5) # Open interval (1, 5)
Interval(
    *args: Union[Sequence[Union[float, int]], float, int],
    is_closed: Optional[bool] = None,
    closed_L: Optional[bool] = None,
    closed_R: Optional[bool] = None
)
lower: Union[float, int]
upper: Union[float, int]
closed_L: bool
closed_R: bool
singleton_set: Optional[set[Union[float, int]]]
is_closed: bool
is_open: bool
is_half_open: bool
is_singleton: bool
is_empty: bool
is_finite: bool
singleton: Union[float, int]
def get_empty() -> muutils.interval.Interval

def get_singleton(value: Union[float, int]) -> muutils.interval.Interval

def numerical_contained(self, item: Union[float, int]) -> bool

def interval_contained(self, item: muutils.interval.Interval) -> bool

def from_str(cls, input_str: str) -> muutils.interval.Interval
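A small sketch of from_str, assuming it accepts the same bracket notation that str(Interval) produces (as in the class examples above):

>>> i = Interval.from_str("[1, 5)")
>>> str(i)
'[1, 5)'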
def copy(self) -> muutils.interval.Interval

def size(self) -> float
Returns the size of the interval.
Returns:

- float
  the size of the interval

def clamp(self, value: Union[int, float], epsilon: float = 1e-10) -> float
Clamp the given value to the interval bounds.
For open bounds, the clamped value will be slightly inside the interval (by epsilon).
Parameters:

- value : Union[int, float]
  the value to clamp
- epsilon : float
  margin for open bounds (defaults to _EPSILON)

Returns:

- float
  the clamped value

Raises:

- ValueError : if the input value is NaN
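A brief sketch of clamping against a half-open interval (assuming out-of-range values are pulled to the nearest bound, nudged inside by epsilon when that bound is open):

```python
from muutils.interval import Interval

i = Interval(1, 5, closed_L=True)  # the half-open interval [1, 5)
print(i.clamp(0))   # -> 1, the closed lower bound
print(i.clamp(10))  # -> just under 5, since the upper bound is open
print(i.clamp(3))   # -> 3, already inside
```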
def intersection(self, other: muutils.interval.Interval) -> muutils.interval.Interval

def union(self, other: muutils.interval.Interval) -> muutils.interval.Interval
class ClosedInterval(Interval):
Represents a closed interval [a, b]; inherits all behavior and members from Interval (see above).
ClosedInterval(*args: Union[Sequence[float], float], **kwargs: Any)
class OpenInterval(Interval):
Represents an open interval (a, b); inherits all behavior and members from Interval (see above).
OpenInterval(*args: Union[Sequence[float], float], **kwargs: Any)
muutils.json_serialize
submodule for serializing things to json in a recoverable way

you can throw any object into muutils.json_serialize.json_serialize and it will return a JSONitem, meaning a bool, int, float, str, None, list of JSONitems, or a dict mapping to JSONitem.

The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don’t care as much about recovering it, you can throw it into json_serialize and it will just work. If you want to do so in a recoverable way, check out ZANJ.

it will do so by looking in DEFAULT_HANDLERS, which will keep it as-is if it’s already valid, then try to find a .serialize() method on the object, and then have a bunch of special cases. You can add handlers by initializing a JsonSerializer object and passing a sequence of them to handlers_pre

additionally, SerializableDataclass is a special kind of dataclass where you specify how to serialize each field, and a .serialize() method is automatically added to the class. This is done by using the serializable_dataclass decorator, inheriting from SerializableDataclass, and using serializable_field in place of dataclasses.field when defining non-standard fields.

This module plays nicely with and is a dependency of the ZANJ library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
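A minimal sketch of the "just works" path (the outputs follow from the default handlers documented below: tuples become lists, Path objects become posix strings):

```python
from pathlib import Path
from muutils.json_serialize import json_serialize

print(json_serialize({"nums": (1, 2, 3), "where": Path("data/out")}))
# {'nums': [1, 2, 3], 'where': 'data/out'}
```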
def json_serialize(
    obj: Any,
    path: tuple[typing.Union[str, int], ...] = ()
) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]
serialize object to json-serializable object with default config
def serializable_dataclass(
    _cls=None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except,
    on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn,
    methods_no_override: list[str] | None = None,
    **kwargs
)
decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs. any kwargs not listed here are passed to dataclasses.dataclass

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation.
```python
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
```

>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
Parameters:

- _cls : _type_
  class to decorate. don’t pass this arg, just use this as a decorator (defaults to None)
- init : bool
  whether to add an __init__ method (passed to dataclasses.dataclass) (defaults to True)
- repr : bool
  whether to add a __repr__ method (passed to dataclasses.dataclass) (defaults to True)
- order : bool
  whether to add rich comparison methods (passed to dataclasses.dataclass) (defaults to False)
- unsafe_hash : bool
  whether to add a __hash__ method (passed to dataclasses.dataclass) (defaults to False)
- frozen : bool
  whether to make the class frozen (passed to dataclasses.dataclass) (defaults to False)
- properties_to_serialize : Optional[list[str]]
  which properties to add to the serialized data dict. SerializableDataclass only (defaults to None)
- register_handler : bool
  if true, register the class with ZANJ for loading. SerializableDataclass only (defaults to True)
- on_typecheck_error : ErrorMode
  what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false. SerializableDataclass only
- on_typecheck_mismatch : ErrorMode
  what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True. SerializableDataclass only
- methods_no_override : list[str]|None
  list of methods that should not be overridden by the decorator. by default, __eq__, serialize, load, and validate_fields_types are overridden by this function, but you can disable this if you’d rather write your own. dataclasses.dataclass might still overwrite these, and those options take precedence. SerializableDataclass only (defaults to None)
- **kwargs
  (passed to dataclasses.dataclass)

Returns:

- _type_
  the decorated class

Raises:

- KWOnlyError : only raised if kw_only is True and python version is <3.9, since dataclasses.dataclass does not support this
- NotSerializableFieldException : if a field is not a SerializableField
- FieldSerializationError : if there is an error serializing a field
- AttributeError : if a property is not found on the class
- FieldLoadingError : if there is an error loading a field

def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[mappingproxy] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any
) -> Any
Create a new SerializableField
```python
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T] | dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
### ----------------------------------------------------------------------
### new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
```
- serialize: whether to serialize this field when serializing the class
- serialization_fn: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the SerializerHandlers defined in muutils.json_serialize.json_serialize
- loading_fn: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- deserialize_fn: new alternative to loading_fn. takes only the field’s value, not the whole class. if both loading_fn and deserialize_fn are provided, an error will be raised.
- assert_type: whether to assert the type of the field when loading. if False, will not check the type of the field.
- custom_typecheck_fn: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

note that loading_fn takes the dict of the class, not the field. if you wanted a loading_fn that does nothing, you’d write:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )
```

using deserialize_fn instead:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )
```

In the above code, my_field is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide serialization_fn and loading_fn to serialize and load the container. ZANJ will automatically do this for you.

- custom_value_check_fn: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test

def arr_metadata(arr) -> dict[str, list[int] | str | int]
get metadata for a numpy array
def load_array(
    arr: Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]],
    array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None
) -> Any
load a json-serialized array, infer the mode if not specified
BASE_HANDLERS = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'))
JSONitem = typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]
class JsonSerializer:
Json serialization class (holds configs)
Parameters:

- array_mode : ArrayMode
  how to write arrays (defaults to "array_list_meta")
- error_mode : ErrorMode
  what to do when we can’t serialize an object (will use repr as fallback if "ignore" or "warn") (defaults to "except")
- handlers_pre : MonoTuple[SerializerHandler]
  handlers to use before the default handlers (defaults to tuple())
- handlers_default : MonoTuple[SerializerHandler]
  default handlers to use (defaults to DEFAULT_HANDLERS)
- write_only_format : bool
  changes _FORMAT_KEY keys in output to "write_format" (when you want to serialize something in a way that zanj won’t try to recover the object when loading) (defaults to False)

Raises:

- ValueError : on init, if args is not empty
- SerializationException : on json_serialize(), if any error occurs when trying to serialize an object and error_mode is set to ErrorMode.EXCEPT
JsonSerializer(
    *args,
    array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta',
    error_mode: muutils.errormode.ErrorMode = ErrorMode.Except,
    handlers_pre: None = (),
    handlers_default: None = DEFAULT_HANDLERS,
    write_only_format: bool = False
)
array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
error_mode: muutils.errormode.ErrorMode
write_only_format: bool
handlers: None
def json_serialize(
    self,
    obj: Any,
    path: tuple[typing.Union[str, int], ...] = ()
) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]
def hashify(
    self,
    obj: Any,
    path: tuple[typing.Union[str, int], ...] = (),
    force: bool = True
) -> Union[bool, int, float, str, tuple]
try to turn any object into something hashable
def try_catch(func: Callable)
wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
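A minimal sketch of using it as a decorator (the exact format of the returned error string is not specified here):

```python
from muutils.json_serialize import try_catch

@try_catch
def risky(x: int) -> float:
    return 1 / x

print(risky(2))  # 0.5
print(risky(0))  # a serialized error message instead of a raised ZeroDivisionError
```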
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False
) -> bool
checks if two dataclasses which (might) hold numpy arrays are equal
Parameters:

- dc1 : the first dataclass
- dc2 : the second dataclass
- except_when_class_mismatch: bool
  if True, will throw TypeError if the classes are different. if not, will return false by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
- false_when_class_mismatch: bool
  only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
- except_when_field_mismatch: bool
  only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. if True, will throw TypeError if the fields are different. (default: True)

Returns:

- bool : True if the dataclasses are equal, False otherwise

Raises:

- TypeError : if the dataclasses are of different classes
- AttributeError : if the dataclasses have different fields

[START]
▼
┌───────────┐ ┌─────────┐
│dc1 is dc2?├─►│ classes │
└──┬────────┘No│ match? │
──── │ ├─────────┤
(True)◄──┘Yes │No │Yes
──── ▼ ▼
┌────────────────┐ ┌────────────┐
│ except when │ │ fields keys│
│ class mismatch?│ │ match? │
├───────────┬────┘ ├───────┬────┘
│Yes │No │No │Yes
▼ ▼ ▼ ▼
─────────── ┌──────────┐ ┌────────┐
{ raise } │ except │ │ field │
{ TypeError } │ when │ │ values │
─────────── │ field │ │ match? │
│ mismatch?│ ├────┬───┘
├───────┬──┘ │ │Yes
│Yes │No │No ▼
▼ ▼ │ ────
─────────────── ───── │ (True)
{ raise } (False)◄┘ ────
{ AttributeError} ─────
───────────────
class SerializableDataclass(abc.ABC):
Base class for serializable dataclasses
only for linting and type checking, still need to call the serializable_dataclass decorator

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn’t too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]
returns the class as a dict, implemented by using
@serializable_dataclass
decorator
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T
takes in an appropriately structured dict and returns an instance of
the class, implemented by using @serializable_dataclass
decorator
def validate_fields_types(
    self,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool
validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
def validate_field_type(
    self,
    field: muutils.json_serialize.serializable_field.SerializableField | str,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool
given a dataclass, check the field matches the type hint
def diff(
    self,
    other: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
    of_serialized: bool = False
) -> dict[str, typing.Any]
get a rich and recursive diff between two instances of a serializable dataclass
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
Parameters:

- other : SerializableDataclass
  other instance to compare against
- of_serialized : bool
  if true, compare serialized data and not raw values (defaults to False)

Returns:

- dict[str, Any]

Raises:

- ValueError : if the instances are not of the same type
- ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any])
update the instance from a nested dict, useful for configuration from command line args
Parameters:

- nested_dict : dict[str, Any]
  nested dict to update the instance with
muutils.json_serialize.array
this utilities module handles serialization and loading of numpy and torch arrays as json

- array_list_meta is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
- array_b64_meta is the most efficient, but is not human readable.
- external is mostly for use in ZANJ
ArrayMode = typing.Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
def array_n_elements(arr) -> int
get the number of elements in an array
def arr_metadata(arr) -> dict[str, list[int] | str | int]
get metadata for a numpy array
def serialize_array(
    jser: 'JsonSerializer',
    arr: numpy.ndarray,
    path: Union[str, Sequence[str | int]],
    array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None
) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]
serialize a numpy or pytorch array in one of several modes
if the object is zero-dimensional, simply get the unique item
array_mode: ArrayMode can be one of:

- list: serialize as a list of values, no metadata (equivalent to arr.tolist())
- array_list_meta: serialize dict with metadata, actual list under the key data
- array_hex_meta: serialize dict with metadata, actual hex string under the key data
- array_b64_meta: serialize dict with metadata, actual base64 string under the key data

for array_list_meta, array_hex_meta, and array_b64_meta, the serialized object is:
{
_FORMAT_KEY: <array_list_meta|array_hex_meta>,
"shape": arr.shape,
"dtype": str(arr.dtype),
"data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
}
Parameters:

- arr : Any
  array to serialize
- array_mode : ArrayMode
  mode in which to serialize the array (defaults to None and inheriting from jser: JsonSerializer)

Returns:

- JSONitem
  json serialized array

Raises:

- KeyError : if the array mode is not valid

def infer_array_mode(
    arr: Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]
) -> Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
given a serialized array, infer the mode
assumes the array was serialized via serialize_array()
def load_array(
    arr: Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]],
    array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None
) -> Any
load a json-serialized array, infer the mode if not specified
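A sketch of a round trip under the default configuration (json_serialize stores the array as array_list_meta, and load_array infers the mode):

```python
import numpy as np
from muutils.json_serialize import json_serialize, load_array

arr = np.arange(6).reshape(2, 3)
serialized = json_serialize(arr)   # dict with shape, dtype, and nested-list data
restored = load_array(serialized)  # mode inferred from the serialized dict
print(np.array_equal(arr, restored))  # True
```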
muutils.json_serialize.json_serialize
provides the basic framework for json serialization of objects

notably:

- SerializerHandler defines how to serialize a specific type of object
- JsonSerializer handles configuration for which handlers to use
- json_serialize provides the default configuration if you don’t care – call it on any object!

SERIALIZER_SPECIAL_KEYS: None = ('__name__', '__doc__', '__module__', '__class__', '__dict__', '__annotations__')
SERIALIZER_SPECIAL_FUNCS: dict[str, typing.Callable] = {'str': <class 'str'>, 'dir': <built-in function dir>, 'type': <function <lambda>>, 'repr': <function <lambda>>, 'code': <function <lambda>>, 'sourcefile': <function <lambda>>}
SERIALIZE_DIRECT_AS_STR: Set[str] = {"<class 'torch.device'>", "<class 'torch.dtype'>"}
ObjectPath = tuple[typing.Union[str, int], ...]
class SerializerHandler:
a handler for a specific type of object
- `check : Callable[[JsonSerializer, Any], bool]` takes a JsonSerializer and an object, returns whether to use this handler
- `serialize : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]` takes a JsonSerializer, an object, and the current path, returns the serialized object
- `desc : str` description of the handler (optional)
SerializerHandler(
    check: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], bool],
    serialize_func: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]],
    uid: str,
    desc: str
)
check: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], bool]
serialize_func: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]
uid: str
desc: str
def serialize(self) -> dict
serialize the handler info
BASE_HANDLERS: None = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'))
DEFAULT_HANDLERS: None = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings'))
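As a sketch of how handlers compose: a custom handler can be placed ahead of the defaults via `handlers_pre`. The complex-number handler below is illustrative, not part of the library:

```python
from muutils.json_serialize.json_serialize import JsonSerializer, SerializerHandler

# hypothetical handler: store python complex numbers as a {re, im} dict
complex_handler = SerializerHandler(
    check=lambda jser, obj, path: isinstance(obj, complex),
    serialize_func=lambda jser, obj, path: {"re": obj.real, "im": obj.imag},
    uid="complex -> dict",
    desc="complex numbers as {re, im} dicts",
)

jser = JsonSerializer(handlers_pre=(complex_handler,))
assert jser.json_serialize(1 + 2j) == {"re": 1.0, "im": 2.0}
```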
class JsonSerializer:
Json serialization class (holds configs)
- `array_mode : ArrayMode`
  how to write arrays
  (defaults to `"array_list_meta"`)
- `error_mode : ErrorMode`
  what to do when we can’t serialize an object (will use repr as fallback if “ignore” or “warn”)
  (defaults to `"except"`)
- `handlers_pre : MonoTuple[SerializerHandler]`
  handlers to use before the default handlers
  (defaults to `tuple()`)
- `handlers_default : MonoTuple[SerializerHandler]`
  default handlers to use
  (defaults to `DEFAULT_HANDLERS`)
- `write_only_format : bool`
  changes `_FORMAT_KEY` keys in output to “write_format” (when you want to serialize something in a way that zanj won’t try to recover the object when loading)
  (defaults to `False`)

Raises:
- `ValueError` : on init, if `args` is not empty
- `SerializationException` : on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
JsonSerializer(
    *args,
    array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta',
    error_mode: muutils.errormode.ErrorMode = ErrorMode.Except,
    handlers_pre: MonoTuple[SerializerHandler] = (),
    handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
    write_only_format: bool = False
)
array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
error_mode: muutils.errormode.ErrorMode
write_only_format: bool
handlers: None
def json_serialize(
    self,
    obj: Any,
    path: ObjectPath = ()
) -> JSONitem
def hashify(
    self,
    obj: Any,
    path: ObjectPath = (),
    force: bool = True
) -> Union[bool, int, float, str, tuple]
try to turn any object into something hashable
GLOBAL_JSON_SERIALIZER: muutils.json_serialize.json_serialize.JsonSerializer = <muutils.json_serialize.json_serialize.JsonSerializer object>
def json_serialize(
    obj: Any,
    path: ObjectPath = ()
) -> JSONitem
serialize object to json-serializable object with default config
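A minimal sketch of calling this default-config entry point, assuming the function is re-exported at the package level as the module overview suggests; the `Path` value relies on the `path -> str` handler listed above:

```python
from pathlib import Path
from muutils.json_serialize import json_serialize

# nested containers and Path objects are covered by the default handlers
print(json_serialize({"a": [1, 2.5, None], "out_dir": Path("runs/exp1")}))
# {'a': [1, 2.5, None], 'out_dir': 'runs/exp1'}
```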
CantGetTypeHintsWarning
ZanjMissingWarning
zanj_register_loader_serializable_dataclass
FieldIsNotInitOrSerializeWarning
SerializableDataclass__validate_field_type
SerializableDataclass__validate_fields_types__dict
SerializableDataclass__validate_fields_types
SerializableDataclass
get_cls_type_hints_cached
get_cls_type_hints
KWOnlyError
FieldError
NotSerializableFieldException
FieldSerializationError
FieldLoadingError
FieldTypeMismatchError
serializable_dataclass
muutils.json_serialize.serializable_dataclass
save and load objects to and from json or compatible formats in a recoverable way
d = dataclasses.asdict(my_obj)
will give you a dict, but
if some fields are not json-serializable, you will get an error when you
call json.dumps(d)
. This module provides a way around
that.
Instead, you define your class:
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
and then you can call my_obj.serialize()
to get a dict
that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn’t too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
class CantGetTypeHintsWarning(builtins.UserWarning):
special warning for when we can’t get type hints
class ZanjMissingWarning(builtins.UserWarning):
special warning for when `ZANJ` is missing – `register_loader_serializable_dataclass` will not work
def zanj_register_loader_serializable_dataclass(cls: Type[~T])

Register a serializable dataclass with the ZANJ import. this allows `ZANJ().read()` to load the class and not just return plain dicts
class FieldIsNotInitOrSerializeWarning(builtins.UserWarning):
Base class for warnings generated by user code.
def SerializableDataclass__validate_field_type(
    self: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
    field: muutils.json_serialize.serializable_field.SerializableField | str,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool
given a dataclass, check the field matches the type hint

this function backs `SerializableDataclass.validate_field_type`

- `self : SerializableDataclass`
  `SerializableDataclass` instance
- `field : SerializableField | str`
  field to validate; will get it from `self.__dataclass_fields__` if given a `str`
- `on_typecheck_error : ErrorMode`
  what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
  (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

Returns:
- `bool` : `True` if the field type is correct, `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
def SerializableDataclass__validate_fields_types__dict(
    self: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> dict[str, bool]

validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

returns a dict of field names to bools, where the bool is if the field type is valid
def SerializableDataclass__validate_fields_types(
    self: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
class SerializableDataclass(abc.ABC):
Base class for serializable dataclasses
only for linting and type checking, still need to call
serializable_dataclass
decorator
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
and then you can call my_obj.serialize()
to get a dict
that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn’t too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]

returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T

takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
def validate_fields_types(
    self,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
def validate_field_type(
    self,
    field: muutils.json_serialize.serializable_field.SerializableField | str,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool
given a dataclass, check the field matches the type hint
def diff(
    self,
    other: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
    of_serialized: bool = False
) -> dict[str, typing.Any]

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
- `other : SerializableDataclass`
  other instance to compare against
- `of_serialized : bool`
  if true, compare serialized data and not raw values
  (defaults to `False`)

Returns:
- `dict[str, Any]`

Raises:
- `ValueError` : if the instances are not of the same type
- `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any])
update the instance from a nested dict, useful for configuration from command line args
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
def get_cls_type_hints_cached(cls: Type[~T]) -> dict[str, typing.Any]
cached typing.get_type_hints for a class
def get_cls_type_hints(cls: Type[~T]) -> dict[str, typing.Any]
helper function to get type hints for a class
class KWOnlyError(builtins.NotImplementedError):
kw-only dataclasses are not supported in python <3.9
class FieldError(builtins.ValueError):
base class for field errors
class NotSerializableFieldException(FieldError):
field is not a SerializableField
class FieldSerializationError(FieldError):
error while serializing a field
class FieldLoadingError(FieldError):
error while loading a field
class FieldTypeMismatchError(FieldError, builtins.TypeError):
error when a field type does not match the type hint
def serializable_dataclass(
    _cls=None,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except,
    on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn,
    methods_no_override: list[str] | None = None,
    **kwargs
)
decorator to make a dataclass serializable. must also make it
inherit from SerializableDataclass
!!
types will be validated (like pydantic) unless
on_typecheck_mismatch
is set to
ErrorMode.IGNORE
behavior of most kwargs matches that of
dataclasses.dataclass
, but with some additional kwargs. any
kwargs not listed here are passed to
dataclasses.dataclass
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 __annotations__
to determine
fields.
If init is true, an __init__()
method is added to the
class. If repr is true, a __repr__()
method is added. If
order is true, rich comparison dunder methods are added. If unsafe_hash
is true, a __hash__()
method function is added. If frozen
is true, fields may not be assigned to after instance creation.
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str

>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
- `_cls : _type_`
  class to decorate. don’t pass this arg, just use this as a decorator
  (defaults to `None`)
- `init : bool`
  whether to add an `__init__` method (passed to dataclasses.dataclass)
  (defaults to `True`)
- `repr : bool`
  whether to add a `__repr__` method (passed to dataclasses.dataclass)
  (defaults to `True`)
- `order : bool`
  whether to add rich comparison methods (passed to dataclasses.dataclass)
  (defaults to `False`)
- `unsafe_hash : bool`
  whether to add a `__hash__` method (passed to dataclasses.dataclass)
  (defaults to `False`)
- `frozen : bool`
  whether to make the class frozen (passed to dataclasses.dataclass)
  (defaults to `False`)
- `properties_to_serialize : Optional[list[str]]`
  which properties to add to the serialized data dict. SerializableDataclass only
  (defaults to `None`)
- `register_handler : bool`
  if true, register the class with ZANJ for loading. SerializableDataclass only
  (defaults to `True`)
- `on_typecheck_error : ErrorMode`
  what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false. SerializableDataclass only
- `on_typecheck_mismatch : ErrorMode`
  what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`. SerializableDataclass only
- `methods_no_override : list[str]|None`
  list of methods that should not be overridden by the decorator. by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function, but you can disable this if you’d rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence. SerializableDataclass only
  (defaults to `None`)
- `**kwargs`
  (passed to dataclasses.dataclass)

Returns:
- `_type_` : the decorated class

Raises:
- `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this
- `NotSerializableFieldException` : if a field is not a `SerializableField`
- `FieldSerializationError` : if there is an error serializing a field
- `AttributeError` : if a property is not found on the class
- `FieldLoadingError` : if there is an error loading a field
muutils.json_serialize.serializable_field
extends dataclasses.Field
for use with
SerializableDataclass
In particular, instead of using dataclasses.field
, use
serializable_field
to define fields in a
SerializableDataclass
. You provide information on how the
field should be serialized and loaded (as well as anything that goes
into dataclasses.field
) when you define the field, and the
SerializableDataclass
will automatically use those
functions.
class SerializableField(dataclasses.Field):
extension of dataclasses.Field
with additional
serialization properties
SerializableField(
    default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
    default_factory: Union[Callable[[], Any], dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[mappingproxy] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None
)
serialize: bool
serialization_fn: Optional[Callable[[Any], Any]]
loading_fn: Optional[Callable[[Any], Any]]
deserialize_fn: Optional[Callable[[Any], Any]]
assert_type: bool
custom_typecheck_fn: Optional[Callable[[<member 'type' of 'SerializableField' objects>], bool]]
def from_Field(
    cls,
    field: dataclasses.Field
) -> muutils.json_serialize.serializable_field.SerializableField
copy all values from a dataclasses.Field
to new
SerializableField
name
type
default
default_factory
init
repr
hash
compare
metadata
kw_only
def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[mappingproxy] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any
) -> Any
Create a new SerializableField
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
### ----------------------------------------------------------------------
### new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
- `serialize` : whether to serialize this field when serializing the class
- `serialization_fn` : function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
- `loading_fn` : function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- `deserialize_fn` : new alternative to `loading_fn`. takes only the field’s value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
- `assert_type` : whether to assert the type of the field when loading. if `False`, will not check the type of the field.
- `custom_typecheck_fn` : function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

note that `loading_fn` takes the dict of the whole class, not just the field. if you wanted a `loading_fn` that does nothing, you’d write:

class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )
using deserialize_fn
instead:
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )
In the above code, my_field
is an int but will be
serialized as a string.
note that if not using ZANJ, and you have a class inside a container,
you MUST provide serialization_fn
and
loading_fn
to serialize and load the container. ZANJ will
automatically do this for you.
- `custom_value_check_fn` : function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
utilities for json_serialize
BaseType
JSONitem
JSONdict
Hashableitem
UniversalContainer
isinstance_namedtuple
try_catch
SerializationException
string_as_lines
safe_getsource
array_safe_eq
dc_eq
MonoTuple
muutils.json_serialize.util
utilities for json_serialize
BaseType = typing.Union[bool, int, float, str, NoneType]
JSONitem = typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]
JSONdict = typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]]
Hashableitem = typing.Union[bool, int, float, str, tuple]
class UniversalContainer:
contains everything – `x in UniversalContainer()` is always `True`
def isinstance_namedtuple(x: Any) -> bool
checks if x
is a namedtuple
credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def try_catch(func: Callable)
wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
class SerializationException(builtins.Exception):
Common base class for all non-exit exceptions.
def string_as_lines(s: str | None) -> list[str]
for easier reading of long strings in json, split up by newlines
sort of like how jupyter notebooks do it
def safe_getsource(func) -> list[str]
def array_safe_eq(a: Any, b: Any) -> bool

check if two objects are equal, accounting for numpy arrays and torch tensors
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False
) -> bool
checks if two dataclasses which (might) hold numpy arrays are equal
- `dc1` : the first dataclass
- `dc2` : the second dataclass
- `except_when_class_mismatch : bool`
  if `True`, will throw `TypeError` if the classes are different. if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
  (defaults to `False`)
- `false_when_class_mismatch : bool`
  only relevant if `except_when_class_mismatch` is `False`. if `True`, will return `False` if the classes are different. if `False`, will attempt to compare the fields.
- `except_when_field_mismatch : bool`
  only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`. if `True`, will throw `TypeError` if the fields are different.
  (defaults to `True`)

Returns:
- `bool` : True if the dataclasses are equal, False otherwise

Raises:
- `TypeError` : if the dataclasses are of different classes
- `AttributeError` : if the dataclasses have different fields

Decision flow: if `dc1 is dc2`, return `True`. Otherwise, if the classes don’t match, raise `TypeError` when `except_when_class_mismatch` is set, else return `False` (or fall through to field comparison, per the flags above). If the classes match but the field keys don’t, raise `AttributeError` when `except_when_field_mismatch` is set, else return `False`. Finally, return whether all field values match.
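A short sketch of the common case this exists for (numpy comes from the optional `array` extra):

```python
import dataclasses

import numpy as np

from muutils.json_serialize.util import dc_eq

@dataclasses.dataclass
class Point:
    coords: np.ndarray

# plain `==` would raise here, since comparing arrays elementwise
# does not reduce to a single bool
assert dc_eq(Point(np.array([1, 2])), Point(np.array([1, 2])))
assert not dc_eq(Point(np.array([1, 2])), Point(np.array([1, 3])))
```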
class MonoTuple:
tuple type hint, but for a tuple of any length with all the same type
utilities for reading and writing jsonlines files, including gzip support
muutils.jsonlines
utilities for reading and writing jsonlines files, including gzip support
def jsonl_load(
    path: str,
    /,
    *,
    use_gzip: bool | None = None
) -> list[JSONitem]

def jsonl_load_log(path: str, /, *, use_gzip: bool | None = None) -> list[dict]

def jsonl_write(
    path: str,
    items: Sequence[JSONitem],
    use_gzip: bool | None = None,
    gzip_compresslevel: int = 2
) -> None
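A minimal usage sketch (file paths are illustrative):

```python
from muutils.jsonlines import jsonl_load, jsonl_write

items = [{"a": 1}, {"b": [2, 3]}]
jsonl_write("data.jsonl", items)  # one json object per line
assert jsonl_load("data.jsonl") == items

# gzip support: written as data.jsonl.gz with compression enabled
jsonl_write("data.jsonl.gz", items, use_gzip=True)
```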
anonymous getitem class
util for constructing a class which has a getitem method which just calls a function
a lambda
is an anonymous function: kappa is the letter
before lambda in the greek alphabet, hence the name of this class
muutils.kappa
anonymous getitem class
util for constructing a class which has a getitem method which just calls a function
a lambda
is an anonymous function: kappa is the letter
before lambda in the greek alphabet, hence the name of this class
class Kappa(typing.Mapping[~_kappa_K, ~_kappa_V]):
A Mapping is a generic container for associating key/value pairs.
This class provides concrete generic implementations of all methods except for getitem, iter, and len.
Kappa(func_getitem: Callable[[~_kappa_K], ~_kappa_V])
func_getitem
doc
(deprecated) experimenting with logging utilities
muutils.logger
(deprecated) experimenting with logging utilities
class Logger(muutils.logger.simplelogger.SimpleLogger):
logger with more features, including log levels and streams
- `log_path : str | None`
default log file path
(defaults to `None`)
- `log_file : AnyIO | None`
default log io, should have a `.write()` method (pass only this or `log_path`, not both)
(defaults to `None`)
- `timestamp : bool`
whether to add timestamps to every log message (under the `_timestamp` key)
(defaults to `True`)
- `default_level : int`
default log level for streams/messages that don't specify a level
(defaults to `0`)
- `console_print_threshold : int`
log level at which to print to the console, anything greater will not be printed unless overridden by `console_print`
(defaults to `50`)
- `level_header : HeaderFunction`
function for formatting log messages when printing to console
(defaults to `HEADER_FUNCTIONS["md"]`)
- `keep_last_msg_time : bool`
  whether to keep the last message time
  (defaults to `True`)
- `ValueError` : _description_
Logger(
    log_path: str | None = None,
    log_file: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
    default_level: int = 0,
    console_print_threshold: int = 50,
    level_header: muutils.logger.headerfuncs.HeaderFunction = <function md_header_function>,
    streams: Union[dict[str | None, muutils.logger.loggingstream.LoggingStream], Sequence[muutils.logger.loggingstream.LoggingStream]] = (),
    keep_last_msg_time: bool = True,
    timestamp: bool = True,
    **kwargs
)
def log(
    self,
    msg: JSONitem = None,
    lvl: int | None = None,
    stream: str | None = None,
    console_print: bool = False,
    extra_indent: str = '',
    **kwargs
)
logging function
- `msg : JSONitem`
  message (usually string or dict) to be logged
- `lvl : int | None`
  level of message (lower levels are more important)
  (defaults to `None`)
- `console_print : bool`
  override `console_print_threshold` setting
  (defaults to `False`)
- `stream : str | None`
  which stream to log to; logs to the default `None` stream if not given
  (defaults to `None`)

def log_elapsed_last(
    self,
    lvl: int | None = None,
    stream: str | None = None,
    console_print: bool = True,
    **kwargs
) -> float
logs the time elapsed since the last message was printed to the console (in any stream)
def flush_all(self)

flush all streams
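A sketch of typical training-loop usage; the paths and the stream name are illustrative:

```python
from muutils.logger import Logger, LoggingStream

logger = Logger(
    log_path="train.jsonl",  # illustrative path
    console_print_threshold=10,
    streams=(LoggingStream(name="loss", file="loss.jsonl"),),
)

logger.log({"config": {"lr": 1e-3}}, lvl=0)  # metadata on the default stream
logger.log({"iter": 1, "loss": 0.42}, lvl=20, stream="loss")
logger.log_elapsed_last()
logger.flush_all()
```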
class LoggingStream:
properties of a logging stream
- `name: str`
  name of the stream
- `aliases: set[str]`
  aliases for the stream (calls to these names will be redirected to this stream. duplicate aliases will result in errors) TODO: perhaps duplicate aliases should result in duplicate writes?
- `file: str|bool|AnyIO|None`
  file to write to. if `None`, will write to standard log; if `True`, will write to `name + ".log"`; if `False`, will “write” to `NullIO` (throw it away); if a string, will write to that file; if a fileIO type object, will write to that object
- `default_level: int|None`
  default level for this stream
- `default_contents: dict[str, Callable[[], Any]]`
  default contents for this stream
- `last_msg: tuple[float, Any]|None`
  last message written to this stream (timestamp, message)

LoggingStream(
    name: str | None,
    aliases: set[str | None] = <factory>,
    file: Union[str, bool, TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
    default_level: int | None = None,
    default_contents: dict[str, typing.Callable[[], typing.Any]] = <factory>,
    handler: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
)
name: str | None
aliases: set[str | None]
file: Union[str, bool, TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
default_level: int | None = None
default_contents: dict[str, typing.Callable[[], typing.Any]]
handler: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
def make_handler(self) -> Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType]
class SimpleLogger:
logs training data to a jsonl file
SimpleLogger(
    log_path: str | None = None,
    log_file: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
    timestamp: bool = True
)
def log(
    self,
    msg: JSONitem,
    console_print: bool = False,
    **kwargs
)
log a message to the log file, and optionally to the console
class TimerContext:
context manager for timing code
start_time: float
end_time: float
elapsed_time: float
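A minimal sketch, assuming `__enter__` returns the instance so the attributes above are reachable after the block:

```python
from muutils.logger import TimerContext

with TimerContext() as timer:
    total = sum(i * i for i in range(100_000))

print(f"elapsed: {timer.elapsed_time:.4f}s")
```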
muutils.logger.exception_context
class ExceptionContext:
context manager which catches all exceptions happening while the
context is open, .write()
the exception trace to the given
stream, and then raises the exception
for example:
errorfile = open('error.log', 'w')

with ExceptionContext(errorfile):
    # do something that might throw an exception
    # if it does, the exception trace will be written to errorfile
    # and then the exception will be raised
    ...
ExceptionContext(stream)
stream
muutils.logger.headerfuncs
class HeaderFunction(typing.Protocol):
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto[T](Protocol):
def meth(self) -> T:
...
HeaderFunction(*args, **kwargs)
def md_header_function(
    msg: Any,
    lvl: int,
    stream: str | None = None,
    indent_lvl: str = ' ',
    extra_indent: str = '',
    **kwargs
) -> str
standard header function. will output

- `# {msg}` for levels in [0, 9]
- `## {msg}` for levels in [10, 19], and so on
- `[{stream}] # {msg}` for a non-`None` stream, with level headers as before
- `!WARNING! [{stream}] {msg}` for levels in [-9, -1]
- `!!WARNING!! [{stream}] {msg}` for levels in [-19, -10], and so on
HEADER_FUNCTIONS: dict[str, muutils.logger.headerfuncs.HeaderFunction] = {'md': <function md_header_function>}
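Roughly, following the rules above (exact whitespace may differ):

```python
from muutils.logger.headerfuncs import md_header_function

print(md_header_function("training started", lvl=0))
# -> "# training started"
print(md_header_function("checkpoint saved", lvl=15, stream="ckpt"))
# -> "[ckpt] ## checkpoint saved"
```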
muutils.logger.log_util
def get_any_from_stream(stream: list[dict], key: str) -> None
get the first value of a key from a stream. errors if not found
def gather_log(file: str) -> dict[str, list[dict]]
gathers and sorts all streams from a log
def gather_stream(file: str, stream: str) -> list[dict]
gets all entries from a specific stream in a log file
def gather_val(
    file: str,
    stream: str,
    keys: tuple[str],
    allow_skip: bool = True
) -> list[list]
gather specific keys from a specific stream in a log file
example: if “log.jsonl” has contents:
{"a": 1, "b": 2, "c": 3, "_stream": "s1"}
{"a": 4, "b": 5, "c": 6, "_stream": "s1"}
{"a": 7, "b": 8, "c": 9, "_stream": "s2"}
then `gather_val("log.jsonl", "s1", ("a", "b"))` will return

[[1, 2], [4, 5]]
logger with streams & levels, and a timer context manager

- `SimpleLogger` is an extremely simple logger that can write to both console and a file
- `Logger` class handles levels in a slightly different way than default python `logging`, and also has “streams” which allow for different sorts of output in the same logger. this was mostly made with training models in mind and storing both metadata and loss
- `TimerContext` is a context manager that can be used to time the duration of a block of code

muutils.logger.logger

def decode_level(level: int) -> str
muutils.logger.loggingstream
muutils.logger.simplelogger
class NullIO:
null IO class
def write(self, msg: str) -> int
write to nothing! this throws away the message
def flush(self) -> None
flush nothing! this is a no-op
def close(self) -> None
close nothing! this is a no-op
AnyIO = typing.Union[typing.TextIO, muutils.logger.simplelogger.NullIO]
muutils.logger.timing
def filter_time_str(time: str) -> str

assuming format `h:mm:ss`, clips off the hours if it’s `0`
class ProgressEstimator:
estimates progress and can give a progress bar
ProgressEstimator(
    n_total: int,
    pbar_fill: str = '█',
    pbar_empty: str = ' ',
    pbar_bounds: tuple[str, str] = ('|', '|')
)
n_total: int
starttime: float
pbar_fill: str
pbar_empty: str
pbar_bounds: tuple[str, str]
total_str_len: int
def get_timing_raw(self, i: int) -> dict[str, float]
returns dict(elapsed, per_iter, remaining, percent)
def get_pbar(self, i: int, width: int = 30) -> str
returns a progress bar
def get_progress_default(self, i: int) -> str
returns a progress string
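A small sketch of use inside a loop (the work is a stand-in):

```python
from muutils.logger.timing import ProgressEstimator

est = ProgressEstimator(n_total=100)
for i in range(100):
    _ = i * i  # stand-in for real work
    if i % 25 == 0:
        print(est.get_pbar(i), est.get_progress_default(i))
```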
miscellaneous utilities
- `stable_hash` for hashing that is stable across runs
- `muutils.misc.sequence` for sequence manipulation, applying mappings, and string-like operations on lists
- `muutils.misc.string` for sanitizing things for filenames, adjusting docstrings, and converting dicts to filenames
- `muutils.misc.numerical` for turning numbers into nice strings and back
- `muutils.misc.freezing` for freezing things
- `muutils.misc.classes` for some weird class utilities

stable_hash
WhenMissing
empty_sequence_if_attr_false
flatten
list_split
list_join
apply_mapping
apply_mapping_chain
sanitize_name
sanitize_fname
sanitize_identifier
dict_to_filename
dynamic_docstring
shorten_numerical_to_str
str_to_numeric
_SHORTEN_MAP
FrozenDict
FrozenList
freeze
is_abstract
get_all_subclasses
isinstance_by_type_name
IsDataclass
get_hashable_eq_attrs
dataclass_set_equals
muutils.misc
def stable_hash(s: str | bytes) -> int
Returns a stable hash of the given string. not cryptographically secure, but stable between runs
WhenMissing = typing.Literal['except', 'skip', 'include']
def empty_sequence_if_attr_false(itr: Iterable[Any], attr_owner: Any, attr_name: str) -> Iterable[Any]
Returns `itr` if `attr_owner` has the attribute `attr_name` and it boolean casts to `True`. Returns an empty sequence otherwise.

Particularly useful for optionally inserting delimiters into a sequence depending on a `TokenizerElement` attribute.

- `itr: Iterable[Any]`
  the iterable to return if the attribute is `True`
- `attr_owner: Any`
  the object to check for the attribute
- `attr_name: str`
  the name of the attribute to check

Returns:
- `itr: Iterable` if `attr_owner` has the attribute `attr_name` and it boolean casts to `True`, otherwise an empty sequence `()`

def flatten(it: Iterable[Any], levels_to_flatten: int | None = None) -> Generator
Flattens an arbitrarily nested iterable. Flattens all iterable data
types except for str
and bytes
.
Returns a generator over the flattened sequence.

- `it` : Any arbitrarily nested iterable.
- `levels_to_flatten` : Number of levels to flatten by, starting at the outermost layer. If `None`, performs full flattening.
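For instance, per the description above:

```python
from muutils.misc import flatten

# strings are not flattened
assert list(flatten([1, [2, [3, 4]], "ab"])) == [1, 2, 3, 4, "ab"]
# flatten only the outermost level
assert list(flatten([1, [2, [3]]], levels_to_flatten=1)) == [1, 2, [3]]
```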
def list_split(lst: list, val: Any) -> list[list]

split a list into sublists by `val`. similar to `"a_b_c".split("_")`

>>> list_split([1,2,3,0,4,5,0,6], 0)
[[1, 2, 3], [4, 5], [6]]
>>> list_split([0,1,2,3], 0)
[[], [1, 2, 3]]
>>> list_split([1,2,3], 0)
[[1, 2, 3]]
>>> list_split([], 0)
[[]]
def list_join(lst: list, factory: Callable) -> list

add a new instance of `factory()` between each element of `lst`

>>> list_join([1,2,3], lambda: 0)
[1, 0, 2, 0, 3]
>>> list_join([1,2,3], lambda: [time.sleep(0.1), time.time()][1])
[1, 1600000000.0, 2, 1600000000.1, 3]
def apply_mapping(
    mapping: Mapping[~_AM_K, ~_AM_V],
    iter: Iterable[~_AM_K],
    when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]
Given an iterable and a mapping, apply the mapping to the iterable with certain options
Gotcha: if `when_missing` is invalid, this is totally fine until a missing key is actually encountered.

Note: you can use this with `muutils.kappa.Kappa` if you want to pass a function instead of a dict

- `mapping : Mapping[_AM_K, _AM_V]`
  must have `__contains__` and `__getitem__`, both of which take `_AM_K`, and the latter returns `_AM_V`
- `iter : Iterable[_AM_K]`
  the iterable to apply the mapping to
- `when_missing : WhenMissing`
  what to do when a key is missing from the mapping – this is what distinguishes this function from `map`. you can choose from `"skip"`, `"include"` (without converting), and `"except"`
  (defaults to `"skip"`)

return type is one of:
- `list[_AM_V]` if `when_missing` is `"skip"` or `"except"`
- `list[Union[_AM_K, _AM_V]]` if `when_missing` is `"include"`

Raises:
- `KeyError` : if the item is missing from the mapping and `when_missing` is `"except"`
- `ValueError` : if `when_missing` is invalid

def apply_mapping_chain(
    mapping: Mapping[~_AM_K, Iterable[~_AM_V]],
    iter: Iterable[~_AM_K],
    when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]
Given an iterable and a mapping, chain the mappings together
Gotcha: if `when_missing` is invalid, this is totally fine until a missing key is actually encountered.

Note: you can use this with `muutils.kappa.Kappa` if you want to pass a function instead of a dict

- `mapping : Mapping[_AM_K, Iterable[_AM_V]]`
  must have `__contains__` and `__getitem__`, both of which take `_AM_K`, and the latter returns `Iterable[_AM_V]`
- `iter : Iterable[_AM_K]`
  the iterable to apply the mapping to
- `when_missing : WhenMissing`
  what to do when a key is missing from the mapping – this is what distinguishes this function from `map`. you can choose from `"skip"`, `"include"` (without converting), and `"except"`
  (defaults to `"skip"`)

return type is one of:
- `list[_AM_V]` if `when_missing` is `"skip"` or `"except"`
- `list[Union[_AM_K, _AM_V]]` if `when_missing` is `"include"`

Raises:
- `KeyError` : if the item is missing from the mapping and `when_missing` is `"except"`
- `ValueError` : if `when_missing` is invalid
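A sketch combining both forms, with a `Kappa` standing in for a dict as the note above suggests (chain semantics assumed to concatenate the mapped iterables):

```python
from muutils.kappa import Kappa
from muutils.misc import apply_mapping, apply_mapping_chain

assert apply_mapping({1: "a", 2: "b"}, [1, 2, 3]) == ["a", "b"]  # 3 skipped
assert apply_mapping({1: "a"}, [1, 3], when_missing="include") == ["a", 3]

# a Kappa wraps a function as a Mapping, so no key is ever missing here
assert apply_mapping(Kappa(lambda k: k * 2), [1, 2, 3]) == [2, 4, 6]

assert apply_mapping_chain({1: ["a", "b"], 2: ["c"]}, [1, 2]) == ["a", "b", "c"]
```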
def sanitize_name(
    name: str | None,
    additional_allowed_chars: str = '',
    replace_invalid: str = '',
    when_none: str | None = '_None_',
    leading_digit_prefix: str = ''
) -> str
sanitize a string, leaving only alphanumerics and `additional_allowed_chars`

- `name : str | None`
  input string
- `additional_allowed_chars : str`
  additional characters to allow, none by default
  (defaults to `""`)
- `replace_invalid : str`
  character to replace invalid characters with
  (defaults to `""`)
- `when_none : str | None`
  string to return if `name` is `None`. if `None`, raises an exception
  (defaults to `"_None_"`)
- `leading_digit_prefix : str`
  character to prefix the string with if it starts with a digit
  (defaults to `""`)

Returns:
- `str` : sanitized string

def sanitize_fname(fname: str | None, **kwargs) -> str
sanitize a filename to posix standards

additionally allowed characters: `_` (underscore), `-` (dash), and `.` (period)

def sanitize_identifier(fname: str | None, **kwargs) -> str

sanitize an identifier (variable or function name)

additionally allowed characters: `_` (underscore); prefixes the string with `_` if it starts with a digit

def dict_to_filename(
    data: dict,
    format_str: str = '{key}_{val}',
    separator: str = '.',
    max_length: int = 255
)
def dynamic_docstring(**doc_params)
def shorten_numerical_to_str(
    num: int | float,
    small_as_decimal: bool = True,
    precision: int = 1
) -> str

shorten a large numerical value to a string: 1234 -> 1K

precision guaranteed to 1 in 10, but can be higher. reverse of `str_to_numeric`
def str_to_numeric(
    quantity: str,
    mapping: None | bool | dict[str, int | float] = True
) -> int | float
Convert a string representing a quantity to a numeric value.
The string can represent an integer, python float, fraction, or a value shortened via `shorten_numerical_to_str`.
>>> str_to_numeric("5")
5
>>> str_to_numeric("0.1")
0.1
>>> str_to_numeric("1/5")
0.2
>>> str_to_numeric("-1K")
-1000.0
>>> str_to_numeric("1.5M")
1500000.0
>>> str_to_numeric("1.2e2")
120.0
_SHORTEN_MAP = {1000.0: 'K', 1000000.0: 'M', 1000000000.0: 'B', 1000000000000.0: 't', 1000000000000000.0: 'q', 1e+18: 'Q'}
class FrozenDict(builtins.dict):
class FrozenList(builtins.list):
Built-in mutable sequence.
If no argument is given, the constructor creates a new empty list. The argument must be an iterable if specified.
def append(self, value)

Append object to the end of the list.

def extend(self, iterable)

Extend list by appending elements from the iterable.

def insert(self, index, value)

Insert object before index.

def remove(self, value)

Remove first occurrence of value.
Raises ValueError if the value is not present.

def pop(self, index=-1)

Remove and return item at index (default last).
Raises IndexError if list is empty or index is out of range.

def clear(self)

Remove all items from list.
def freeze(instance: Any) -> Any
recursively freeze an object in-place so that its attributes and elements cannot be changed
messy in the sense that sometimes the object is modified in place, but you can’t rely on that. always use the return value.
the gelidum package is a more complete implementation of this idea
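A sketch, assuming mutation of the frozen result raises:

```python
from muutils.misc import freeze

cfg = freeze({"lr": 0.1, "layers": [2, 3]})
try:
    cfg["lr"] = 0.2  # expected to fail on the frozen dict
except Exception as e:
    print(f"blocked: {type(e).__name__}")
```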
def is_abstract(cls: type) -> bool
Returns whether a class is abstract.
def get_all_subclasses(class_: type, include_self=False) -> set[type]
Returns a set containing all child classes in the subclass graph of
class_
. I.e., includes subclasses of subclasses, etc.
- `class_` : Superclass
- `include_self` : Whether to include `class_` itself in the returned set

Since most class hierarchies are small, the inefficiencies of the existing recursive implementation aren’t problematic. It might be valuable to refactor with memoization if the need arises to use this function on a very large class hierarchy.
def isinstance_by_type_name(o: object, type_name: str)
Behaves like stdlib isinstance
except it accepts a
string representation of the type rather than the type itself. This is a
hacky function intended to circumvent the need to import a type into a
module. It is susceptible to type name collisions.
- `o` : Object (not the type itself) whose type to interrogate
- `type_name` : The string returned by `type_.__name__`. Generic types are not supported, only types that would appear in `type_.__mro__`.
class IsDataclass(typing.Protocol):
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto[T](Protocol):
def meth(self) -> T:
...
IsDataclass(*args, **kwargs)
def get_hashable_eq_attrs(dc: muutils.misc.classes.IsDataclass) -> tuple[typing.Any]
Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself. The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical. Essentially used to generate a hashable dataclass representation for equality comparison even if it’s not frozen.
def dataclass_set_equals(
    coll1: Iterable[muutils.misc.classes.IsDataclass],
    coll2: Iterable[muutils.misc.classes.IsDataclass]
) -> bool
Compares 2 collections of dataclass instances as if they were sets. Duplicates are ignored in the same manner as a set. Unfrozen dataclasses can’t be placed in sets since they’re not hashable. Collections of them may be compared using this function.
is_abstract
get_all_subclasses
isinstance_by_type_name
IsDataclass
get_hashable_eq_attrs
dataclass_set_equals
muutils.misc.classes
muutils.misc.freezing
class FrozenDict(builtins.dict):
class FrozenList(builtins.list):
Built-in mutable sequence.
If no argument is given, the constructor creates a new empty list. The argument must be an iterable if specified.
def append(self, value)
Append object to the end of the list.
def extend(self, iterable)
Extend list by appending elements from the iterable.
def insert(self, index, value)
Insert object before index.
def remove(self, value)
Remove first occurrence of value.
Raises ValueError if the value is not present.
def pop(self, index=-1)
Remove and return item at index (default last).
Raises IndexError if list is empty or index is out of range.
def clear(self)
Remove all items from list.
def freeze(instance: Any) -> Any
recursively freeze an object in-place so that its attributes and elements cannot be changed
messy in the sense that sometimes the object is modified in place, but you can’t rely on that. always use the return value.
the gelidum package is a more complete implementation of this idea
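A minimal sketch of the intended usage (the class here is hypothetical):

class Config:
    def __init__(self):
        self.lr = 0.1

cfg = freeze(Config())  # always use the return value, per the note above
# cfg.lr = 0.5  # mutating the frozen instance would now raise an error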
FuncParams
FuncParamsPreWrap
process_kwarg
validate_kwarg
replace_kwarg
is_none
always_true
always_false
format_docstring
LambdaArgs
typed_lambda
muutils.misc.func
FuncParams = ~FuncParams
FuncParamsPreWrap = ~FuncParamsPreWrap
def process_kwarg(
    kwarg_name: str,
    processor: Callable[[~T_process_in], ~T_process_out]
) -> Callable[[Callable[~FuncParamsPreWrap, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]
Decorator that applies a processor to a keyword argument.
The underlying function is expected to have a keyword argument (with
name kwarg_name
) of type T_out
, but the caller
provides a value of type T_in
that is converted via
processor
.
kwarg_name : str
The name of the keyword argument to process.
processor : Callable[[T_in], T_out]
A callable that converts the input value (T_in) into the type expected by the function (T_out).
Returns a decorator that converts a function of type Callable[OutputParams, ReturnType] (expecting kwarg_name of type T_out) into one of type Callable[InputParams, ReturnType] (accepting kwarg_name of type T_in).
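For instance, a decorator built this way could let callers pass a string where the wrapped function expects an int (a hypothetical sketch, consistent with the description above):

@process_kwarg("count", int)
def repeat(s: str, count: int = 1) -> str:
    return s * count

assert repeat("a", count="3") == "aaa"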
def validate_kwarg(
    kwarg_name: str,
    validator: Callable[[~T_kwarg], bool],
    description: str | None = None,
    action: muutils.errormode.ErrorMode = ErrorMode.Except
) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]
Decorator that validates a specific keyword argument.
kwarg_name : str
The name of the keyword argument to validate.
validator : Callable[[Any], bool]
A callable that returns True if the keyword argument is valid.
description : str | None
A message template if validation fails.
action : str
Either "raise" (default) or "warn".
Returns Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]], a decorator that validates the keyword argument.
If action=="warn", emits a warning. Otherwise, raises a ValueError.

@validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
def my_func(x: int) -> int:
    return x

assert my_func(x=1) == 1

Raises ValueError if validation fails and action == "raise".
def replace_kwarg(
    kwarg_name: str,
    check: Callable[[~T_kwarg], bool],
    replacement_value: ~T_kwarg,
    replace_if_missing: bool = False
) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]
Decorator that replaces a specific keyword argument value by identity comparison.
kwarg_name : str
The name of the keyword argument to replace.
check : Callable[[T_kwarg], bool]
A callable that returns True if the keyword argument should be replaced.
replacement_value : T_kwarg
The value to replace with.
replace_if_missing : bool
If True, replaces the keyword argument even if it's missing.
Returns Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]], a decorator that replaces the keyword argument value.
Replaces kwargs[kwarg_name] if check returns True on its value.

@replace_kwarg("x", is_none, "default_string")
def my_func(*, x: str | None = None) -> str:
    return x

assert my_func(x=None) == "default_string"
def is_none(value: Any) -> bool
def always_true(value: Any) -> bool
def always_false(value: Any) -> bool
def format_docstring(**fmt_kwargs: Any) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]
Decorator that formats a function’s docstring with the provided keyword arguments.
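A small hypothetical sketch of how this might be used:

@format_docstring(fmt="unicode")
def summarize():
    """summarize an array (default format: {fmt})"""
    ...

assert "unicode" in summarize.__doc__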
LambdaArgs = LambdaArgs
def typed_lambda(
    fn: Callable[[Unpack[LambdaArgs]], ~ReturnType],
    in_types: ~LambdaArgsTypes,
    out_type: type[~ReturnType]
) -> Callable[[Unpack[LambdaArgs]], ~ReturnType]
Wraps a lambda function with type hints.
fn : Callable[[Unpack[LambdaArgs]], ReturnType]
The lambda function to wrap.
in_types : tuple[type, ...]
Tuple of input types.
out_type : type[ReturnType]
The output type.
Returns Callable[..., ReturnType], a new function with annotations matching the given signature.

add = typed_lambda(lambda x, y: x + y, (int, int), int)
assert add(1, 2) == 3

Raises ValueError if the number of input types doesn't match the lambda's parameters.
muutils.misc.hashing
def stable_hash(s: str | bytes) -> int
Returns a stable hash of the given string. not cryptographically secure, but stable between runs
def stable_json_dumps(d) -> str
def base64_hash(s: str | bytes) -> str
Returns a base64 representation of the hash of the given string. not cryptographically secure
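A doctest-style sketch (exact digests omitted, since the point is only stability across runs):

>>> stable_hash("hello") == stable_hash("hello")
True
>>> isinstance(base64_hash(b"hello"), str)
True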
muutils.misc.numerical
def shorten_numerical_to_str(
    num: int | float,
    small_as_decimal: bool = True,
    precision: int = 1
) -> str
shorten a large numerical value to a string: 1234 -> 1K
precision guaranteed to 1 in 10, but can be higher. reverse of str_to_numeric
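A doctest-style sketch, following the 1234 -> 1K example above and the 123456789 -> "123M" example from the package readme:

>>> shorten_numerical_to_str(1234)
'1K'
>>> shorten_numerical_to_str(123456789)
'123M'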
def str_to_numeric(
    quantity: str,
    mapping: None | bool | dict[str, int | float] = True
) -> int | float
Convert a string representing a quantity to a numeric value.
The string can represent an integer, python float, fraction, or
shortened via shorten_numerical_to_str
.
>>> str_to_numeric("5")
5
>>> str_to_numeric("0.1")
0.1
>>> str_to_numeric("1/5")
0.2
>>> str_to_numeric("-1K")
-1000.0
>>> str_to_numeric("1.5M")
1500000.0
>>> str_to_numeric("1.2e2")
120.0
WhenMissing
empty_sequence_if_attr_false
flatten
list_split
list_join
apply_mapping
apply_mapping_chain
muutils.misc.sequence
WhenMissing = typing.Literal['except', 'skip', 'include']
def empty_sequence_if_attr_false(itr: Iterable[Any], attr_owner: Any, attr_name: str) -> Iterable[Any]
Returns itr if attr_owner has the attribute attr_name and it boolean casts to True. Returns an empty sequence otherwise.
Particularly useful for optionally inserting delimiters into a sequence depending on a TokenizerElement attribute.
itr: Iterable[Any]
The iterable to return if the attribute is True.
attr_owner: Any
The object to check for the attribute.
attr_name: str
The name of the attribute to check.
Returns itr: Iterable, if attr_owner has the attribute attr_name and it boolean casts to True; otherwise an empty sequence (), if the attribute is False or not present.
def flatten(it: Iterable[Any], levels_to_flatten: int | None = None) -> Generator
Flattens an arbitrarily nested iterable. Flattens all iterable data
types except for str
and bytes
.
Generator over the flattened sequence.
it : Iterable[Any]
Any arbitrarily nested iterable.
levels_to_flatten : int | None
Number of levels to flatten by, starting at the outermost layer. If None, performs full flattening.
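A doctest-style sketch consistent with the description:

>>> list(flatten([1, [2, [3, 4]], 5]))
[1, 2, 3, 4, 5]
>>> list(flatten([1, [2, [3, 4]], 5], levels_to_flatten=1))
[1, 2, [3, 4], 5]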
def list_split(lst: list, val: Any) -> list[list]
split a list into sublists by val. similar to "a_b_c".split("_")
>>> list_split([1,2,3,0,4,5,0,6], 0)
[[1, 2, 3], [4, 5], [6]]
>>> list_split([0,1,2,3], 0)
[[], [1, 2, 3]]
>>> list_split([1,2,3], 0)
[[1, 2, 3]]
>>> list_split([], 0)
[[]]
def list_join(lst: list, factory: Callable) -> list
add a new instance of factory() between each element of lst

>>> list_join([1,2,3], lambda : 0)
[1, 0, 2, 0, 3]
>>> list_join([1,2,3], lambda: [time.sleep(0.1), time.time()][1])
[1, 1600000000.0, 2, 1600000000.1, 3]
def apply_mapping(
    mapping: Mapping[~_AM_K, ~_AM_V],
    iter: Iterable[~_AM_K],
    when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]
Given an iterable and a mapping, apply the mapping to the iterable with certain options
Gotcha: if when_missing
is invalid, this is totally fine
until a missing key is actually encountered.
Note: you can use this with muutils.kappa.Kappa if you want to pass a function instead of a dict
mapping : Mapping[_AM_K, _AM_V]
must have __contains__ and __getitem__, both of which take _AM_K, and the latter returns _AM_V
iter : Iterable[_AM_K]
the iterable to apply the mapping to
when_missing : WhenMissing
what to do when a key is missing from the mapping (this is what distinguishes this function from map). you can choose from "skip", "include" (without converting), and "except" (defaults to "skip")
return type is one of:
- list[_AM_V] if when_missing is "skip" or "except"
- list[Union[_AM_K, _AM_V]] if when_missing is "include"
KeyError : if the item is missing from the mapping and when_missing is "except"
ValueError : if when_missing is invalid
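A doctest-style sketch of the when_missing modes:

>>> apply_mapping({"a": 1, "b": 2}, ["a", "b", "c"], when_missing="skip")
[1, 2]
>>> apply_mapping({"a": 1, "b": 2}, ["a", "b", "c"], when_missing="include")
[1, 2, 'c']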
def apply_mapping_chain(
    mapping: Mapping[~_AM_K, Iterable[~_AM_V]],
    iter: Iterable[~_AM_K],
    when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]
Given an iterable and a mapping, chain the mappings together
Gotcha: if when_missing
is invalid, this is totally fine
until a missing key is actually encountered.
Note: you can use this with muutils.kappa.Kappa if you want to pass a function instead of a dict
mapping : Mapping[_AM_K, Iterable[_AM_V]]
must have __contains__ and __getitem__, both of which take _AM_K, and the latter returns Iterable[_AM_V]
iter : Iterable[_AM_K]
the iterable to apply the mapping to
when_missing : WhenMissing
what to do when a key is missing from the mapping (this is what distinguishes this function from map). you can choose from "skip", "include" (without converting), and "except" (defaults to "skip")
return type is one of:
- list[_AM_V] if when_missing is "skip" or "except"
- list[Union[_AM_K, _AM_V]] if when_missing is "include"
KeyError : if the item is missing from the mapping and when_missing is "except"
ValueError : if when_missing is invalid
muutils.misc.string
def sanitize_name(
    name: str | None,
    additional_allowed_chars: str = '',
    replace_invalid: str = '',
    when_none: str | None = '_None_',
    leading_digit_prefix: str = ''
) -> str
sanitize a string, leaving only alphanumerics and additional_allowed_chars
name : str | None
input string
additional_allowed_chars : str
additional characters to allow, none by default (defaults to "")
replace_invalid : str
character to replace invalid characters with (defaults to "")
when_none : str | None
string to return if name is None. if None, raises an exception (defaults to "_None_")
leading_digit_prefix : str
character to prefix the string with if it starts with a digit (defaults to "")
Returns str : sanitized string
def sanitize_fname(fname: str | None, **kwargs) -> str
sanitize a filename to posix standards. allowed characters besides alphanumerics: _ (underscore), - (dash), and . (period)
def sanitize_identifier(fname: str | None, **kwargs) -> str
sanitize an identifier (variable or function name). allows alphanumerics and _ (underscore), and prefixes with _ if it starts with a digit
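A doctest-style sketch of the sanitizers (expected outputs follow the rules above, so treat them as illustrative):

>>> sanitize_fname("my file (v2).txt")
'myfilev2.txt'
>>> sanitize_identifier("2nd pass")
'_2ndpass'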
def dict_to_filename(
    data: dict,
    format_str: str = '{key}_{val}',
    separator: str = '.',
    max_length: int = 255
)
def dynamic_docstring(**doc_params)
miscellaneous utilities for ML pipelines
ARRAY_IMPORTS
DEFAULT_SEED
GLOBAL_SEED
get_device
set_reproducibility
chunks
get_checkpoint_paths_for_run
register_method
pprint_summary
muutils.mlutils
ARRAY_IMPORTS: bool = True
DEFAULT_SEED: int = 42
GLOBAL_SEED: int = 42
def get_device(device: Union[str, torch.device, NoneType] = None) -> torch.device
Get the torch.device instance on which torch.Tensors should be allocated.
def set_reproducibility(seed: int = 42)
Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.
Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades off performance for reproducibility. Set use_deterministic_algorithms to False to favor performance over reproducibility.
def chunks(it, chunk_size)
Yield successive chunks from an iterator.
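A doctest-style sketch (each chunk is wrapped in list() so the output shape is explicit):

>>> [list(c) for c in chunks(iter(range(7)), 3)]
[[0, 1, 2], [3, 4, 5], [6]]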
def get_checkpoint_paths_for_run(
    run_path: pathlib.Path,
    extension: Literal['pt', 'zanj'],
    checkpoints_format: str = 'checkpoints/model.iter_*.{extension}'
) -> list[tuple[int, pathlib.Path]]
get checkpoints of the format from the run_path
note that checkpoints_format should contain a glob pattern with:
- an unresolved "{extension}" format term for the extension
- a wildcard for the iteration number
def register_method(
    method_dict: dict[str, typing.Callable[..., typing.Any]],
    custom_name: Optional[str] = None
) -> Callable[[~F], ~F]
Decorator to add a method to the method_dict
def pprint_summary(summary: dict)
utilities for working with notebooks
configure_notebook
convert_ipynb_to_script
run_notebook_tests
mermaid
print_tex
muutils.nbutils
def mm(graph)
for plotting mermaid.js diagrams
shared utilities for setting up a notebook
PlotlyNotInstalledWarning
PLOTLY_IMPORTED
PlottingMode
PLOT_MODE
CONVERSION_PLOTMODE_OVERRIDE
FIG_COUNTER
FIG_OUTPUT_FMT
FIG_NUMBERED_FNAME
FIG_CONFIG
FIG_BASEPATH
CLOSE_AFTER_PLOTSHOW
MATPLOTLIB_FORMATS
TIKZPLOTLIB_FORMATS
UnknownFigureFormatWarning
universal_savefig
setup_plots
configure_notebook
plotshow
muutils.nbutils.configure_notebook
class PlotlyNotInstalledWarning(builtins.UserWarning):
Base class for warnings generated by user code.
PLOTLY_IMPORTED: bool = True
PlottingMode = typing.Literal['ignore', 'inline', 'widget', 'save']
PLOT_MODE: Literal['ignore', 'inline', 'widget', 'save'] = 'inline'
CONVERSION_PLOTMODE_OVERRIDE: Optional[Literal['ignore', 'inline', 'widget', 'save']] = None
FIG_COUNTER: int = 0
FIG_OUTPUT_FMT: str | None = None
FIG_NUMBERED_FNAME: str = 'figure-{num}'
FIG_CONFIG: dict | None = None
FIG_BASEPATH: str | None = None
CLOSE_AFTER_PLOTSHOW: bool = False
MATPLOTLIB_FORMATS = ['pdf', 'png', 'jpg', 'jpeg', 'svg', 'eps', 'ps', 'tif', 'tiff']
TIKZPLOTLIB_FORMATS = ['tex', 'tikz']
class UnknownFigureFormatWarning(builtins.UserWarning):
Base class for warnings generated by user code.
def universal_savefig(fname: str, fmt: str | None = None) -> None
def setup_plots(
    plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline',
    fig_output_fmt: str | None = 'pdf',
    fig_numbered_fname: str = 'figure-{num}',
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False
) -> None
Set up plot saving/rendering options
def configure_notebook(
    *args,
    seed: int = 42,
    device: Any = None,
    dark_mode: bool = True,
    plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline',
    fig_output_fmt: str | None = 'pdf',
    fig_numbered_fname: str = 'figure-{num}',
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False
) -> torch.device | None
Shared Jupyter notebook setup steps
seed : int
random seed across libraries including torch, numpy, and random (defaults to 42)
device : typing.Any
pytorch device to use (defaults to None)
dark_mode : bool
figures in dark mode (defaults to True)
plot_mode : PlottingMode
how to display plots, one of PlottingMode or ["ignore", "inline", "widget", "save"] (defaults to "inline")
fig_output_fmt : str | None
format for saving figures (defaults to "pdf")
fig_numbered_fname : str
format for saving figures with numbers (if they aren't named) (defaults to "figure-{num}")
fig_config : dict | None
metadata to save with the figures (defaults to None)
fig_basepath : str | None
base path for saving figures (defaults to None)
close_after_plotshow : bool
close figures after showing them (defaults to False)
Returns torch.device | None, the device set, if torch is installed.
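A typical setup cell might look like the following sketch, using plotshow (documented next); the figure path and saved filename follow the defaults described above, so treat them as assumptions:

import matplotlib.pyplot as plt
from muutils.nbutils.configure_notebook import configure_notebook, plotshow

device = configure_notebook(seed=42, plot_mode="save", fig_basepath="figures")
plt.plot([1, 2, 3])
plotshow()  # in "save" mode, writes something like figures/figure-1.pdf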
def plotshow(
    fname: str | None = None,
    plot_mode: Optional[Literal['ignore', 'inline', 'widget', 'save']] = None,
    fmt: str | None = None
)
Show the active plot, depending on global configs
fast conversion of Jupyter Notebooks to scripts, with some basic and hacky filtering and formatting.
muutils.nbutils.convert_ipynb_to_script
DISABLE_PLOTS: dict[str, list[str]] = {'matplotlib': ['\n# ------------------------------------------------------------\n# Disable matplotlib plots, done during processing by
convert_ipynb_to_script.py\nimport matplotlib.pyplot as plt\nplt.show = lambda: None\n# ------------------------------------------------------------\n'], 'circuitsvis': ['\n# ------------------------------------------------------------\n# Disable circuitsvis plots, done during processing by
convert_ipynb_to_script.py\nfrom circuitsvis.utils.convert_props import PythonProperty, convert_props\nfrom circuitsvis.utils.render import RenderedHTML, render, render_cdn, render_local\n\ndef new_render(\n react_element_name: str,\n **kwargs: PythonProperty\n) -> RenderedHTML:\n "return a visualization as raw HTML"\n local_src = render_local(react_element_name, **kwargs)\n cdn_src = render_cdn(react_element_name, **kwargs)\n # return as string instead of RenderedHTML for CI\n return str(RenderedHTML(local_src, cdn_src))\n\nrender = new_render\n# ------------------------------------------------------------\n'], 'muutils': ['import muutils.nbutils.configure_notebook as nb_conf\nnb_conf.CONVERSION_PLOTMODE_OVERRIDE = "ignore"\n']}
DISABLE_PLOTS_WARNING: list[str] = ["# ------------------------------------------------------------\n# WARNING: this script is auto-generated by
convert_ipynb_to_script.py\n# showing plots has been disabled, so this is presumably in a temp dict for CI or something\n# so don't modify this code, it will be overwritten!\n# ------------------------------------------------------------\n"]
def disable_plots_in_script(script_lines: list[str]) -> list[str]
Disable plots in a script by adding cursed things after the import statements
def convert_ipynb(
    notebook: dict,
    strip_md_cells: bool = False,
    header_comment: str = '#%%',
    disable_plots: bool = False,
    filter_out_lines: Union[str, Sequence[str]] = ('%', '!')
) -> str
Convert Jupyter Notebook to a script, doing some basic filtering and formatting.
- `notebook: dict`: Jupyter Notebook loaded as json.
- `strip_md_cells: bool = False`: Remove markdown cells from the output script.
- `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
- `disable_plots: bool = False`: Disable plots in the output script.
- `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
if a string is passed, it will be split by char and each char will be treated as a separate filter.
- `str`: Converted script.
def process_file(
    in_file: str,
    out_file: str | None = None,
    strip_md_cells: bool = False,
    header_comment: str = '#%%',
    disable_plots: bool = False,
    filter_out_lines: Union[str, Sequence[str]] = ('%', '!')
)
def process_dir(
    input_dir: Union[str, pathlib.Path],
    output_dir: Union[str, pathlib.Path],
    strip_md_cells: bool = False,
    header_comment: str = '#%%',
    disable_plots: bool = False,
    filter_out_lines: Union[str, Sequence[str]] = ('%', '!')
)
Convert all Jupyter Notebooks in a directory to scripts.
- `input_dir: str`: Input directory.
- `output_dir: str`: Output directory.
- `strip_md_cells: bool = False`: Remove markdown cells from the output script.
- `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
- `disable_plots: bool = False`: Disable plots in the output script.
- `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
if a string is passed, it will be split by char and each char will be treated as a separate filter.
display mermaid.js diagrams in jupyter notebooks via the mermaid.ink/img service
muutils.nbutils.mermaid
def mm(graph)
for plotting mermaid.js diagrams
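In a notebook cell, something like this sketch should render a small diagram via the service:

mm("""
graph TD
    A[notebook] --> B[diagram]
""")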
quickly print a sympy expression in latex
muutils.nbutils.print_tex
def print_tex(
    expr: sympy.core.expr.Expr,
    name: str | None = None,
    plain: bool = False,
    rendered: bool = True
)
function for easily rendering a sympy expression in latex
turn a folder of notebooks into scripts, run them, and make sure they work.
made to be called as
python -m muutils.nbutils.run_notebook_tests --notebooks-dir <notebooks_dir> --converted-notebooks-temp-dir <converted_notebooks_temp_dir>
muutils.nbutils.run_notebook_tests
class NotebookTestError(builtins.Exception):
Common base class for all non-exit exceptions.
SUCCESS_STR: str = '✅'
FAILURE_STR: str = '❌'
def run_notebook_tests(
    notebooks_dir: pathlib.Path,
    converted_notebooks_temp_dir: pathlib.Path,
    CI_output_suffix: str = '.CI-output.txt',
    run_python_cmd: Optional[str] = None,
    run_python_cmd_fmt: str = '{python_tool} run python',
    python_tool: str = 'poetry',
    exit_on_first_fail: bool = False
)
Run converted Jupyter notebooks as Python scripts and verify they execute successfully.
Takes a directory of notebooks and their corresponding converted Python scripts, executes each script, and captures the output. Failures are collected and reported, with optional early exit on first failure.
notebooks_dir : Path
Directory containing the original .ipynb notebook files
converted_notebooks_temp_dir : Path
Directory containing the corresponding converted .py files
CI_output_suffix : str
Suffix to append to output files capturing execution results (defaults to ".CI-output.txt")
run_python_cmd : str | None
Custom command to run Python scripts. Overrides python_tool and run_python_cmd_fmt if provided (defaults to None)
run_python_cmd_fmt : str
Format string for constructing the Python run command (defaults to "{python_tool} run python")
python_tool : str
Tool used to run Python (e.g. poetry, uv) (defaults to "poetry")
exit_on_first_fail : bool
Whether to raise exception immediately on first notebook failure (defaults to False)
Returns None.
Raises NotebookTestError if any notebooks fail to execute, or if input directories are invalid; TypeError if run_python_cmd is provided but not a string.

>>> run_notebook_tests(
...     notebooks_dir=Path("notebooks"),
...     converted_notebooks_temp_dir=Path("temp/converted"),
...     python_tool="poetry"
... )
### testing notebooks in 'notebooks'
### reading converted notebooks from 'temp/converted'
1/2: temp/converted/notebook1.py
Running temp/converted/notebook1.py
Output in temp/converted/notebook1.CI-output.txt
{SUCCESS_STR} Run completed with return code 0
ProgressBarFunction
ProgressBarOption
DEFAULT_PBAR_FN
spinner_fn_wrap
map_kwargs_for_tqdm
no_progress_fn_wrap
set_up_progress_bar_fn
run_maybe_parallel
muutils.parallel
class ProgressBarFunction(typing.Protocol):
a protocol for a progress bar function
ProgressBarFunction(*args, **kwargs)
ProgressBarOption = typing.Literal['tqdm', 'spinner', 'none', None]
DEFAULT_PBAR_FN: Literal['tqdm', 'spinner', 'none', None] = 'tqdm'
def spinner_fn_wrap(x: Iterable, **kwargs) -> List
spinner wrapper
def map_kwargs_for_tqdm(kwargs: dict) -> dict
map kwargs for tqdm; can't wrap because the pbar disappears?
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable
fallback to no progress bar
def set_up_progress_bar_fn(
    pbar: Union[muutils.parallel.ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs
) -> Tuple[muutils.parallel.ProgressBarFunction, dict]
set up the progress bar function and its kwargs
pbar : Union[ProgressBarFunction, ProgressBarOption]
progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
pbar_kwargs : Optional[Dict[str, Any]]
kwargs passed to the progress bar function (defaults to None)
Returns Tuple[ProgressBarFunction, dict], a tuple of the progress bar function and its kwargs.
Raises ValueError if pbar is not one of the valid options.
def run_maybe_parallel(
    func: Callable[[~InputType], ~OutputType],
    iterable: Iterable[~InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[muutils.parallel.ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]] = 'tqdm'
) -> List[~OutputType]
a function to make it easier to sometimes parallelize an operation
- if parallel is False, the function will run in serial, running map(func, iterable)
- if parallel is True, the function will run in parallel with the maximum number of processes
- if parallel is an int, it must be greater than 1, and the function will run in parallel with the number of processes specified by parallel
the maximum number of processes is given by min(len(iterable), multiprocessing.cpu_count())
func : Callable[[InputType], OutputType]
function passed to either map or Pool.imap
iterable : Iterable[InputType]
iterable passed to either map or Pool.imap
parallel : bool | int
whether to run in parallel, and how many processes to use
pbar_kwargs : Dict[str, Any]
kwargs passed to the progress bar function
Returns List[OutputType], a list of the output of func for each element in iterable.
Raises:
ValueError : if parallel is not a boolean or an integer greater than 1
ValueError : if use_multiprocess=True and parallel=False
ImportError : if use_multiprocess=True and multiprocess is not available
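A doctest-style sketch of the serial path (parallel=False), with the progress bar disabled:

>>> run_maybe_parallel(lambda x: x ** 2, range(4), parallel=False, pbar="none")
[0, 1, 4, 9]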
decorator spinner_decorator and context manager SpinnerContext to display a spinner, using the base Spinner class, while some code is running.
DecoratedFunction
SpinnerConfig
SpinnerConfigArg
SPINNERS
Spinner
NoOpContextManager
SpinnerContext
spinner_decorator
muutils.spinner
DecoratedFunction = ~DecoratedFunction
Define a generic type for the decorated function
class SpinnerConfig:
SpinnerConfig(working: List[str] = <factory>, success: str = '✔️', fail: str = '❌')
working: List[str]
success: str = '✔️'
fail: str = '❌'
def is_ascii(self) -> bool
whether all characters are ascii
def eq_lens(self) -> bool
whether all working characters are the same length
def is_valid(self) -> bool
whether the spinner config is valid
def from_any(cls, arg: Union[str, List[str], muutils.spinner.SpinnerConfig, dict]) -> muutils.spinner.SpinnerConfig
SpinnerConfigArg = typing.Union[str, typing.List[str], muutils.spinner.SpinnerConfig, dict]
SPINNERS: Dict[str, muutils.spinner.SpinnerConfig] = {'default': SpinnerConfig(working=['|', '/', '-', '\\'], success='#', fail='X'), 'dots': SpinnerConfig(working=['. ', '.. ', '...'], success='***', fail='xxx'), 'bars': SpinnerConfig(working=['| ', '|| ', '|||'], success='|||', fail='///'), 'arrows': SpinnerConfig(working=['<', '^', '>', 'v'], success='►', fail='✖'), 'arrows_2': SpinnerConfig(working=['←', '↖', '↑', '↗', '→', '↘', '↓', '↙'], success='→', fail='↯'), 'bouncing_bar': SpinnerConfig(working=['[ ]', '[= ]', '[== ]', '[=== ]', '[ ===]', '[ ==]', '[ =]'], success='[====]', fail='[XXXX]'), 'bar': SpinnerConfig(working=['[ ]', '[- ]', '[--]', '[ -]'], success='[==]', fail='[xx]'), 'bouncing_ball': SpinnerConfig(working=['( ● )', '( ● )', '( ● )', '( ● )', '( ●)', '( ● )', '( ● )', '( ● )', '( ● )', '(● )'], success='(●●●●●●)', fail='( ✖ )'), 'ooo': SpinnerConfig(working=['.', 'o', 'O', 'o'], success='O', fail='x'), 'braille': SpinnerConfig(working=['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'], success='⣿', fail='X'), 'clock': SpinnerConfig(working=['🕛', '🕐', '🕑', '🕒', '🕓', '🕔', '🕕', '🕖', '🕗', '🕘', '🕙', '🕚'], success='✔️', fail='❌'), 'hourglass': SpinnerConfig(working=['⏳', '⌛'], success='✔️', fail='❌'), 'square_corners': SpinnerConfig(working=['◰', '◳', '◲', '◱'], success='◼', fail='✖'), 'triangle': SpinnerConfig(working=['◢', '◣', '◤', '◥'], success='◆', fail='✖'), 'square_dot': SpinnerConfig(working=['⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽', '⣾'], success='⣿', fail='❌'), 'box_bounce': SpinnerConfig(working=['▌', '▀', '▐', '▄'], success='■', fail='✖'), 'hamburger': SpinnerConfig(working=['☱', '☲', '☴'], success='☰', fail='✖'), 'earth': SpinnerConfig(working=['🌍', '🌎', '🌏'], success='✔️', fail='❌'), 'growing_dots': SpinnerConfig(working=['⣀', '⣄', '⣤', '⣦', '⣶', '⣷', '⣿'], success='⣿', fail='✖'), 'dice': SpinnerConfig(working=['⚀', '⚁', '⚂', '⚃', '⚄', '⚅'], success='🎲', fail='✖'), 'wifi': SpinnerConfig(working=['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], success='✔️', fail='❌'), 'bounce': SpinnerConfig(working=['⠁', '⠂', '⠄', '⠂'], success='⠿', fail='⢿'), 'arc': SpinnerConfig(working=['◜', '◠', '◝', '◞', '◡', '◟'], success='○', fail='✖'), 'toggle': SpinnerConfig(working=['⊶', '⊷'], success='⊷', fail='⊗'), 'toggle2': SpinnerConfig(working=['▫', '▪'], success='▪', fail='✖'), 'toggle3': SpinnerConfig(working=['□', '■'], success='■', fail='✖'), 'toggle4': SpinnerConfig(working=['■', '□', '▪', '▫'], success='■', fail='✖'), 'toggle5': SpinnerConfig(working=['▮', '▯'], success='▮', fail='✖'), 'toggle7': SpinnerConfig(working=['⦾', '⦿'], success='⦿', fail='✖'), 'toggle8': SpinnerConfig(working=['◍', '◌'], success='◍', fail='✖'), 'toggle9': SpinnerConfig(working=['◉', '◎'], success='◉', fail='✖'), 'arrow2': SpinnerConfig(working=['⬆️ ', '↗️ ', '➡️ ', '↘️ ', '⬇️ ', '↙️ ', '⬅️ ', '↖️ '], success='➡️', fail='❌'), 'point': SpinnerConfig(working=['∙∙∙', '●∙∙', '∙●∙', '∙∙●', '∙∙∙'], success='●●●', fail='xxx'), 'layer': SpinnerConfig(working=['-', '=', '≡'], success='≡', fail='✖'), 'speaker': SpinnerConfig(working=['🔈 ', '🔉 ', '🔊 ', '🔉 '], success='🔊', fail='🔇'), 'orangePulse': SpinnerConfig(working=['🔸 ', '🔶 ', '🟠 ', '🟠 ', '🔷 '], success='🟠', fail='❌'), 'bluePulse': SpinnerConfig(working=['🔹 ', '🔷 ', '🔵 ', '🔵 ', '🔷 '], success='🔵', fail='❌'), 'satellite_signal': SpinnerConfig(working=['📡 ', '📡· ', '📡·· ', '📡···', '📡 ··', '📡 ·'], success='📡 ✔️ ', fail='📡 ❌ '), 'rocket_orbit': SpinnerConfig(working=['🌍🚀 ', '🌏 🚀 ', '🌎 🚀'], success='🌍 ✨', fail='🌍 💥'), 'ogham': SpinnerConfig(working=['ᚁ ', 'ᚂ ', 'ᚃ ', 
'ᚄ', 'ᚅ'], success='᚛᚜', fail='✖'), 'eth': SpinnerConfig(working=['᛫', '፡', '፥', '፤', '፧', '።', '፨'], success='፠', fail='✖')}
class Spinner:
displays a spinner, and optionally elapsed time and a mutable value while a function is running.
update_interval : float
how often to update the spinner display in seconds (defaults to 0.1)
initial_value : str
initial value to display with the spinner (defaults to "")
message : str
message to display with the spinner (defaults to "")
format_string : str
string to format the spinner with. must have "\r" prepended to clear the line. allowed keys are spinner, elapsed_time, message, and value (defaults to "\r{spinner} ({elapsed_time:.2f}s) {message}{value}")
output_stream : TextIO
stream to write the spinner to (defaults to sys.stdout)
format_string_when_updated : Union[bool, str]
whether to use a different format string when the value is updated. if True, use the default format string with a newline appended. if a string, use that string. this is useful if you want update_value to print to console and be preserved. (defaults to False)
spinner_chars : Union[str, Sequence[str]]
sequence of strings, or key to look up in SPINNER_CHARS, to use as the spinner characters (defaults to "default")
spinner_complete : str
string to display when the spinner is complete (defaults to looking up spinner_chars in SPINNER_COMPLETE or "#")

update_value(value: Any) -> None
update the current value displayed by the spinner

with SpinnerContext() as sp:
    for i in range(1):
        time.sleep(0.1)
        sp.update_value(f"Step {i+1}")

@spinner_decorator
def long_running_function():
    for i in range(1):
        time.sleep(0.1)
        spinner.update_value(f"Step {i+1}")
    return "Function completed"
Spinner(
    *args,
    config: Union[str, List[str], muutils.spinner.SpinnerConfig, dict] = 'default',
    update_interval: float = 0.1,
    initial_value: str = '',
    message: str = '',
    format_string: str = '\r{spinner} ({elapsed_time:.2f}s) {message}{value}',
    output_stream: TextIO = <_io.StringIO object>,
    format_string_when_updated: Union[str, bool] = False,
    spinner_chars: Union[str, Sequence[str], NoneType] = None,
    spinner_complete: Optional[str] = None,
    **kwargs: Any
)
config: muutils.spinner.SpinnerConfig
format_string_when_updated: Optional[str]
format string to use when the value is updated
update_interval: float
message: str
current_value: Any
format_string: str
output_stream: <class 'TextIO'>
start_time: float
for measuring elapsed time
stop_spinner: threading.Event
to stop the spinner
spinner_thread: Optional[threading.Thread]
the thread running the spinner
value_changed: bool
whether the value has been updated since the last display
term_width: int
width of the terminal, for padding with spaces
state: Literal['initialized', 'running', 'success', 'fail']
def spin(self) -> None
Function to run in a separate thread, displaying the spinner and optional information
def update_value(self, value: Any) -> None
Update the current value displayed by the spinner
def start(self) -> None
Start the spinner
def stop(self, failed: bool = False) -> None
Stop the spinner
class NoOpContextManager(typing.ContextManager):
A context manager that does nothing.
NoOpContextManager(*args, **kwargs)
class SpinnerContext(Spinner, typing.ContextManager):
displays a spinner, and optionally elapsed time and a mutable value while a function is running.
update_interval : float
how often to update the spinner display in seconds (defaults to 0.1)
initial_value : str
initial value to display with the spinner (defaults to "")
message : str
message to display with the spinner (defaults to "")
format_string : str
string to format the spinner with. must have "\r" prepended to clear the line. allowed keys are spinner, elapsed_time, message, and value (defaults to "\r{spinner} ({elapsed_time:.2f}s) {message}{value}")
output_stream : TextIO
stream to write the spinner to (defaults to sys.stdout)
format_string_when_updated : Union[bool, str]
whether to use a different format string when the value is updated. if True, use the default format string with a newline appended. if a string, use that string. this is useful if you want update_value to print to console and be preserved. (defaults to False)
spinner_chars : Union[str, Sequence[str]]
sequence of strings, or key to look up in SPINNER_CHARS, to use as the spinner characters (defaults to "default")
spinner_complete : str
string to display when the spinner is complete (defaults to looking up spinner_chars in SPINNER_COMPLETE or "#")

update_value(value: Any) -> None
update the current value displayed by the spinner

with SpinnerContext() as sp:
    for i in range(1):
        time.sleep(0.1)
        sp.update_value(f"Step {i+1}")

@spinner_decorator
def long_running_function():
    for i in range(1):
        time.sleep(0.1)
        spinner.update_value(f"Step {i+1}")
    return "Function completed"
Spinner
config
format_string_when_updated
update_interval
message
current_value
format_string
output_stream
start_time
stop_spinner
spinner_thread
value_changed
term_width
state
spin
update_value
start
stop
def spinner_decorator(
    *args,
    config: Union[str, List[str], muutils.spinner.SpinnerConfig, dict] = 'default',
    update_interval: float = 0.1,
    initial_value: str = '',
    message: str = '',
    format_string: str = '{spinner} ({elapsed_time:.2f}s) {message}{value}',
    output_stream: TextIO = <_io.StringIO object>,
    mutable_kwarg_key: Optional[str] = None,
    spinner_chars: Union[str, Sequence[str], NoneType] = None,
    spinner_complete: Optional[str] = None,
    **kwargs
) -> Callable[[~DecoratedFunction], ~DecoratedFunction]
displays a spinner, and optionally elapsed time and a mutable value while a function is running.
update_interval : float
how often to update the spinner display in seconds (defaults to 0.1)
initial_value : str
initial value to display with the spinner (defaults to "")
message : str
message to display with the spinner (defaults to "")
format_string : str
string to format the spinner with. must have "\r" prepended to clear the line. allowed keys are spinner, elapsed_time, message, and value (defaults to "\r{spinner} ({elapsed_time:.2f}s) {message}{value}")
output_stream : TextIO
stream to write the spinner to (defaults to sys.stdout)
format_string_when_updated : Union[bool, str]
whether to use a different format string when the value is updated. if True, use the default format string with a newline appended. if a string, use that string. this is useful if you want update_value to print to console and be preserved. (defaults to False)
spinner_chars : Union[str, Sequence[str]]
sequence of strings, or key to look up in SPINNER_CHARS, to use as the spinner characters (defaults to "default")
spinner_complete : str
string to display when the spinner is complete (defaults to looking up spinner_chars in SPINNER_COMPLETE or "#")

update_value(value: Any) -> None
update the current value displayed by the spinner

with SpinnerContext() as sp:
    for i in range(1):
        time.sleep(0.1)
        sp.update_value(f"Step {i+1}")

@spinner_decorator
def long_running_function():
    for i in range(1):
        time.sleep(0.1)
        spinner.update_value(f"Step {i+1}")
    return "Function completed"
StatCounter class for counting and calculating statistics on numbers
cleaner and more efficient than just using a Counter or array
muutils.statcounter
NumericSequence = typing.Sequence[typing.Union[float, int, ForwardRef('NumericSequence')]]
def universal_flatten(
    arr: Union[Sequence[Union[float, int, Sequence[Union[float, int, ForwardRef('NumericSequence')]]]], float, int],
    require_rectangular: bool = True
) -> Sequence[Union[float, int, ForwardRef('NumericSequence')]]
flattens any iterable
class StatCounter(collections.Counter):
Counter
, but with some stat calculation methods which
assume the keys are numerical
works best when the keys are ints
def validate(self) -> bool
validate the counter as being all floats or ints
def min(self)
minimum value
def max(self)
maximum value
def total(self)
Sum of the counts
keys_sorted: list
return the keys, sorted
def percentile(self, p: float)
return the value at the given percentile
this could be log time if we did binary search, but that would be a lot of added complexity
def median(self) -> float
def mean(self) -> float
return the mean of the values
def mode(self) -> float
def std(self) -> float
return the standard deviation of the values
def summary(
    self,
    typecast: Callable = <function StatCounter.<lambda>>,
    *,
    extra_percentiles: Optional[list[float]] = None
) -> dict[str, typing.Union[float, int]]
return a summary of the stats, without the raw data. human readable and small
def serialize(
    self,
    typecast: Callable = <function StatCounter.<lambda>>,
    *,
    extra_percentiles: Optional[list[float]] = None
) -> dict
return a json-serializable version of the counter
includes both the output of summary
and the raw
data:
{
"StatCounter": { <keys, values from raw data> },
"summary": self.summary(typecast, extra_percentiles=extra_percentiles),
}
def load(cls, data: dict) -> muutils.statcounter.StatCounter
load from the output of StatCounter.serialize
def from_list_arrays(
    cls,
    arr,
    map_func: Callable = <class 'float'>
) -> muutils.statcounter.StatCounter
calls map_func
on each element of
universal_flatten(arr)
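A doctest-style sketch of the stats methods (values follow directly from the counts):

>>> c = StatCounter([1, 1, 2, 3])
>>> c.mean()
1.75
>>> c.min(), c.max()
(1, 3)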
utilities for getting information about the system, see the SysInfo class
muutils.sysinfo
class SysInfo:
getters for various information about the system
def python() -> dict
details about python version
def pip() -> dict
installed packages info
def pytorch() -> dict
pytorch and cuda information
def platform() -> dict
def git_info(with_log: bool = False) -> dict
def get_all(
    cls,
    include: Optional[tuple[str, ...]] = None,
    exclude: tuple[str, ...] = ()
) -> dict
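A short sketch of pulling a subset of the info (the section names passed to include are assumed to match the getter names above):

from muutils.sysinfo import SysInfo

info = SysInfo.get_all(include=("python", "platform"))
print(info["python"])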
COLORS
OutputFormat
SYMBOLS
SPARK_CHARS
array_info
generate_sparkline
DEFAULT_SETTINGS
array_summary
muutils.tensor_info
COLORS: Dict[str, Dict[str, str]] = {'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'reset': ''}}
OutputFormat = typing.Literal['unicode', 'latex', 'ascii']
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] = {'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'nan_values': '\\text{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla', 'true': '\\checkmark', 'false': '\\times'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇', 'true': '✓', 'false': '✗'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad', 'true': '1', 'false': '0'}}
Symbols for different formats
SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] = {'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}
characters for sparklines in different formats
def array_info(A: Any, hist_bins: int = 5) -> Dict[str, Any]
Extract statistical information from an array-like object.
A : array-like
Array to analyze (numpy array or torch tensor)
Returns Dict[str, Any], a dictionary containing raw statistical information with numeric values.
def generate_sparkline(
    histogram: numpy.ndarray,
    format: Literal['unicode', 'latex', 'ascii'] = 'unicode',
    log_y: bool = False
) -> str
Generate a sparkline visualization of the histogram.
histogram : np.ndarray
Histogram data
format : Literal["unicode", "latex", "ascii"]
Output format (defaults to "unicode")
log_y : bool
Whether to use logarithmic y-scale (defaults to False)
Returns str, the sparkline visualization.
DEFAULT_SETTINGS: Dict[str, Any] = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': False, 'colored': False, 'as_list': False, 'eq_char': '='}
def array_summary(
    array,
    fmt: Literal['unicode', 'latex', 'ascii'] = <muutils.tensor_info._UseDefaultType object>,
    precision: int = <muutils.tensor_info._UseDefaultType object>,
    stats: bool = <muutils.tensor_info._UseDefaultType object>,
    shape: bool = <muutils.tensor_info._UseDefaultType object>,
    dtype: bool = <muutils.tensor_info._UseDefaultType object>,
    device: bool = <muutils.tensor_info._UseDefaultType object>,
    requires_grad: bool = <muutils.tensor_info._UseDefaultType object>,
    sparkline: bool = <muutils.tensor_info._UseDefaultType object>,
    sparkline_bins: int = <muutils.tensor_info._UseDefaultType object>,
    sparkline_logy: bool = <muutils.tensor_info._UseDefaultType object>,
    colored: bool = <muutils.tensor_info._UseDefaultType object>,
    eq_char: str = <muutils.tensor_info._UseDefaultType object>,
    as_list: bool = <muutils.tensor_info._UseDefaultType object>
) -> Union[str, List[str]]
Format array information into a readable summary.
array
array-like object (numpy array or torch tensor)
precision : int
Decimal places (defaults to 2)
fmt : Literal["unicode", "latex", "ascii"]
Output format (defaults to {default_fmt})
stats : bool
Whether to include statistical info (μ, σ, x̃) (defaults to True)
shape : bool
Whether to include shape info (defaults to True)
dtype : bool
Whether to include dtype info (defaults to True)
device : bool
Whether to include device info for torch tensors (defaults to True)
requires_grad : bool
Whether to include requires_grad info for torch tensors (defaults to True)
sparkline : bool
Whether to include a sparkline visualization (defaults to False)
sparkline_width : int
Width of the sparkline (defaults to 20)
sparkline_logy : bool
Whether to use logarithmic y-scale for sparkline (defaults to False)
colored : bool
Whether to add color to output (defaults to False)
as_list : bool
Whether to return as list of strings instead of joined string (defaults to False)
Returns Union[str, List[str]], the formatted statistical summary, either as string or list of strings.
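A minimal sketch (printed output varies with the array contents, so none is shown):

import numpy as np
from muutils.tensor_info import array_summary

print(array_summary(np.random.randn(100), fmt="ascii"))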
utilities for working with tensors and arrays.
notably:
- TYPE_TO_JAX_DTYPE : a mapping from python, numpy, and torch types to jaxtyping types
- DTYPE_MAP : mapping string representations of types to their type
- TORCH_DTYPE_MAP : mapping string representations of types to torch types
- compare_state_dicts : for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
TYPE_TO_JAX_DTYPE
jaxtype_factory
ATensor
NDArray
numpy_to_torch_dtype
DTYPE_LIST
DTYPE_MAP
TORCH_DTYPE_MAP
TORCH_OPTIMIZERS_MAP
pad_tensor
lpad_tensor
rpad_tensor
pad_array
lpad_array
rpad_array
get_dict_shapes
string_dict_shapes
StateDictCompareError
StateDictKeysError
StateDictShapeError
StateDictValueError
compare_state_dicts
muutils.tensor_utils
TYPE_TO_JAX_DTYPE: dict = {<class 'float'>: <class 'jaxtyping.Float'>, <class 'int'>: <class 'jaxtyping.Int'>, <class 'jaxtyping.Float'>: <class 'jaxtyping.Float'>, <class 'jaxtyping.Int'>: <class 'jaxtyping.Int'>, <class 'bool'>: <class 'jaxtyping.Bool'>, <class 'jaxtyping.Bool'>: <class 'jaxtyping.Bool'>, <class 'numpy.bool_'>: <class 'jaxtyping.Bool'>, torch.bool: <class 'jaxtyping.Bool'>, <class 'numpy.float16'>: <class 'jaxtyping.Float'>, <class 'numpy.float32'>: <class 'jaxtyping.Float'>, <class 'numpy.float64'>: <class 'jaxtyping.Float'>, <class 'numpy.int8'>: <class 'jaxtyping.Int'>, <class 'numpy.int16'>: <class 'jaxtyping.Int'>, <class 'numpy.int32'>: <class 'jaxtyping.Int'>, <class 'numpy.int64'>: <class 'jaxtyping.Int'>, <class 'numpy.longlong'>: <class 'jaxtyping.Int'>, <class 'numpy.uint8'>: <class 'jaxtyping.Int'>, torch.float32: <class 'jaxtyping.Float'>, torch.float16: <class 'jaxtyping.Float'>, torch.float64: <class 'jaxtyping.Float'>, torch.bfloat16: <class 'jaxtyping.Float'>, torch.int32: <class 'jaxtyping.Int'>, torch.int8: <class 'jaxtyping.Int'>, torch.int16: <class 'jaxtyping.Int'>, torch.int64: <class 'jaxtyping.Int'>}
dict mapping python, numpy, and torch types to jaxtyping
types
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=<class 'jaxtyping.Float'>,
    legacy_mode: Union[muutils.errormode.ErrorMode, str] = ErrorMode.Warn
) -> type
usage:
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
ATensor = <class 'muutils.tensor_utils.jaxtype_factory.<locals>._BaseArray'>
NDArray = <class 'muutils.tensor_utils.jaxtype_factory.<locals>._BaseArray'>
def numpy_to_torch_dtype(dtype: Union[numpy.dtype, torch.dtype]) -> torch.dtype
convert numpy dtype to torch dtype
DTYPE_LIST: list = [<class 'bool'>, <class 'int'>, <class 'float'>, torch.float32, torch.float32, torch.float64, torch.float16, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.int32, torch.int8, torch.int16, torch.int32, torch.int64, torch.int64, torch.int16, torch.uint8, torch.bool, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.int8'>, <class 'numpy.int16'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.longlong'>, <class 'numpy.int16'>, <class 'numpy.uint8'>, <class 'numpy.bool_'>, <class 'numpy.float64'>, <class 'numpy.int64'>]
list of all the python, numpy, and torch numerical types I could think of
DTYPE_MAP: dict = {"<class 'bool'>": <class 'bool'>, "<class 'int'>": <class 'int'>, "<class 'float'>": <class 'float'>, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": <class 'numpy.float16'>, "<class 'numpy.float32'>": <class 'numpy.float32'>, "<class 'numpy.float64'>": <class 'numpy.float64'>, "<class 'numpy.complex64'>": <class 'numpy.complex64'>, "<class 'numpy.complex128'>": <class 'numpy.complex128'>, "<class 'numpy.int8'>": <class 'numpy.int8'>, "<class 'numpy.int16'>": <class 'numpy.int16'>, "<class 'numpy.int32'>": <class 'numpy.int32'>, "<class 'numpy.int64'>": <class 'numpy.int64'>, "<class 'numpy.longlong'>": <class 'numpy.longlong'>, "<class 'numpy.uint8'>": <class 'numpy.uint8'>, "<class 'numpy.bool_'>": <class 'numpy.bool_'>, 'float16': <class 'numpy.float16'>, 'float32': <class 'numpy.float32'>, 'float64': <class 'numpy.float64'>, 'complex64': <class 'numpy.complex64'>, 'complex128': <class 'numpy.complex128'>, 'int8': <class 'numpy.int8'>, 'int16': <class 'numpy.int16'>, 'int32': <class 'numpy.int32'>, 'int64': <class 'numpy.int64'>, 'longlong': <class 'numpy.longlong'>, 'uint8': <class 'numpy.uint8'>, 'bool_': <class 'numpy.bool_'>, 'bool': <class 'numpy.bool_'>}
mapping from string representations of types to their type
TORCH_DTYPE_MAP: dict = {"<class 'bool'>": torch.bool, "<class 'int'>": torch.int64, "<class 'float'>": torch.float64, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": torch.float16, "<class 'numpy.float32'>": torch.float32, "<class 'numpy.float64'>": torch.float64, "<class 'numpy.complex64'>": torch.complex64, "<class 'numpy.complex128'>": torch.complex128, "<class 'numpy.int8'>": torch.int8, "<class 'numpy.int16'>": torch.int16, "<class 'numpy.int32'>": torch.int32, "<class 'numpy.int64'>": torch.int64, "<class 'numpy.longlong'>": torch.int64, "<class 'numpy.uint8'>": torch.uint8, "<class 'numpy.bool_'>": torch.bool, 'float16': torch.float16, 'float32': torch.float32, 'float64': torch.float64, 'complex64': torch.complex64, 'complex128': torch.complex128, 'int8': torch.int8, 'int16': torch.int16, 'int32': torch.int32, 'int64': torch.int64, 'longlong': torch.int64, 'uint8': torch.uint8, 'bool_': torch.bool, 'bool': torch.bool}
mapping from string representations of types to specifically torch types
TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.optimizer.Optimizer]] = {'Adagrad': <class 'torch.optim.adagrad.Adagrad'>, 'Adam': <class 'torch.optim.adam.Adam'>, 'AdamW': <class 'torch.optim.adamw.AdamW'>, 'SparseAdam': <class 'torch.optim.sparse_adam.SparseAdam'>, 'Adamax': <class 'torch.optim.adamax.Adamax'>, 'ASGD': <class 'torch.optim.asgd.ASGD'>, 'LBFGS': <class 'torch.optim.lbfgs.LBFGS'>, 'NAdam': <class 'torch.optim.nadam.NAdam'>, 'RAdam': <class 'torch.optim.radam.RAdam'>, 'RMSprop': <class 'torch.optim.rmsprop.RMSprop'>, 'Rprop': <class 'torch.optim.rprop.Rprop'>, 'SGD': <class 'torch.optim.sgd.SGD'>}
def pad_tensor(
    tensor: jaxtyping.Shaped[Tensor, 'dim1'],
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False
) -> jaxtyping.Shaped[Tensor, 'padded_length']
pad a 1-d tensor on the left with pad_value to length
padded_length
set rpad = True
to pad on the right instead
def lpad_tensor(
    tensor: torch.Tensor,
    padded_length: int,
    pad_value: float = 0.0
) -> torch.Tensor
pad a 1-d tensor on the left with pad_value to length
padded_length
def rpad_tensor(
    tensor: torch.Tensor,
    pad_length: int,
    pad_value: float = 0.0
) -> torch.Tensor
pad a 1-d tensor on the right with pad_value to length
pad_length
def pad_array(
    array: jaxtyping.Shaped[ndarray, 'dim1'],
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False
) -> jaxtyping.Shaped[ndarray, 'padded_length']
pad a 1-d array on the left with pad_value to length
padded_length
set rpad = True
to pad on the right instead
def lpad_array(
    array: numpy.ndarray,
    padded_length: int,
    pad_value: float = 0.0
) -> numpy.ndarray
pad a 1-d array on the left with pad_value to length
padded_length
def rpad_array(
    array: numpy.ndarray,
    pad_length: int,
    pad_value: float = 0.0
) -> numpy.ndarray
pad a 1-d array on the right with pad_value to length
pad_length
def get_dict_shapes(d: dict[str, torch.Tensor]) -> dict[str, tuple[int, ...]]
given a state dict or cache dict, compute the shapes and put them in a nested dict
def string_dict_shapes(d: dict[str, torch.Tensor]) -> str
printable version of get_dict_shapes
class StateDictCompareError(builtins.AssertionError):
raised when state dicts don’t match
class StateDictKeysError(StateDictCompareError):
raised when state dict keys don’t match
class StateDictShapeError(StateDictCompareError):
raised when state dict shapes don’t match
class StateDictValueError(StateDictCompareError):
raised when state dict values don’t match
def compare_state_dicts(
    d1: dict,
    d2: dict,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    verbose: bool = True
) -> None
compare two dicts of tensors
d1 : dict
d2 : dict
rtol : float
(defaults to 1e-5)
atol : float
(defaults to 1e-8)
verbose : bool
(defaults to True)
Raises:
StateDictKeysError : keys don't match
StateDictShapeError : shapes don't match (but keys do)
StateDictValueError : values don't match (but keys and shapes do)
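A minimal sketch of the comparison (shapes and values here are chosen to pass):

import torch
from muutils.tensor_utils import compare_state_dicts

d1 = {"w": torch.ones(2, 3)}
d2 = {"w": torch.ones(2, 3)}
compare_state_dicts(d1, d2)  # returns None; raises on mismatch
# compare_state_dicts(d1, {"w": torch.zeros(2, 3)})  # would raise StateDictValueError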
timeit_fancy is just a fancier version of timeit with more options
muutils.timeit_fancy
class FancyTimeitResult(typing.NamedTuple):
return type of timeit_fancy
FancyTimeitResult(
    timings: ForwardRef('StatCounter'),
    return_value: ForwardRef('T'),
    profile: ForwardRef('Union[pstats.Stats, None]')
)
Create new instance of FancyTimeitResult(timings, return_value, profile)
timings: muutils.statcounter.StatCounter
Alias for field number 0
return_value: ~T
Alias for field number 1
profile: Optional[pstats.Stats]
Alias for field number 2
def timeit_fancy(
    cmd: Union[Callable[[], ~T], str],
    setup: Union[str, Callable[[], Any]] = <function <lambda>>,
    repeats: int = 5,
    namespace: Optional[dict[str, Any]] = None,
    get_return: bool = True,
    do_profiling: bool = False
) -> muutils.timeit_fancy.FancyTimeitResult
Wrapper for timeit
to get the fastest run of a callable
with more customization options.
Approximates the functionality of the %timeit magic or command line interface in a Python callable.
cmd: Callable[[], T] | str
The callable to time. If a string, it will be passed to timeit.Timer as the stmt argument.
setup: str
The setup code to run before cmd. If a string, it will be passed to timeit.Timer as the setup argument.
repeats: int
The number of times to run cmd to get a reliable measurement.
namespace: dict[str, Any]
Passed to the timeit.Timer constructor. If cmd or setup use local or global variables, they must be passed here. See the timeit documentation for details.
get_return: bool
Whether to pass the value returned from cmd. If True, the return value will be appended in a tuple with execution time. This is for speed and convenience so that cmd doesn't need to be run again in the calling scope if the return values are needed. (default: True)
do_profiling: bool
Whether to return a pstats.Stats object in addition to the time and return value. (default: False)
Returns a FancyTimeitResult, which is a NamedTuple with the following fields:
timings
The time in seconds it took to run cmd the minimum number of times to get a reliable measurement.
return_value: T | None
The return value of cmd if get_return is True, otherwise None.
profile: pstats.Stats | None
A pstats.Stats object if do_profiling is True, otherwise None.
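A short usage sketch (the exact timing numbers will vary):

result = timeit_fancy(lambda: sum(range(1000)), repeats=10)
print(result.timings.mean())  # timings is a StatCounter
print(result.return_value)    # 499500, since get_return defaults to True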
experimental utility for validating types in python, see validate_type
GenericAliasTypes
IncorrectTypeException
TypeHintNotImplementedError
InvalidGenericAliasError
validate_type
get_fn_allowed_kwargs
muutils.validate_type
GenericAliasTypes: tuple = (<class 'types.GenericAlias'>, <class 'typing._GenericAlias'>, <class 'typing._UnionGenericAlias'>, <class 'typing._BaseGenericAlias'>)
class IncorrectTypeException(builtins.TypeError):
Inappropriate argument type.
class TypeHintNotImplementedError(builtins.NotImplementedError):
Method or function hasn’t been implemented yet.
class InvalidGenericAliasError(builtins.TypeError):
Inappropriate argument type.
def validate_type(value: Any, expected_type: Any, do_except: bool = False) -> bool
Validate that a value is of the expected_type.
value : the value to check the type of
expected_type : the type to check against. Not all types are supported
do_except : if True, raise an exception if the type is incorrect (instead of returning False) (default: False)
Returns bool : True if the value is of the expected type, False otherwise.
Raises:
IncorrectTypeException(TypeError) : if the type is incorrect and do_except is True
TypeHintNotImplementedError(NotImplementedError) : if the type hint is not implemented
InvalidGenericAliasError(TypeError) : if the generic alias is invalid
use typeguard for a more robust solution: https://github.com/agronholm/typeguard
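A doctest-style sketch of the basic checks (the utility is experimental, so the set of supported hints may vary):

>>> validate_type([1, 2], list[int])
True
>>> validate_type({"a": 1}, dict[str, str])
False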
def get_fn_allowed_kwargs(fn: Callable) -> Set[str]
Get the allowed kwargs for a function, raising an exception if the signature cannot be determined.