Coverage for muutils / misc / func.py: 84%
86 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
1from __future__ import annotations
2import functools
3import sys
4from types import CodeType
5import warnings
6from typing import Any, Callable, Tuple, cast, TypeVar
8# TODO: we do a lot of type weirdness here that basedpyright doesn't like
9# pyright: reportInvalidTypeForm=false
11try:
12 if sys.version_info >= (3, 11):
13 # 3.11+
14 from typing import Unpack, TypeVarTuple, ParamSpec
15 else:
16 # 3.9+
17 from typing_extensions import Unpack, TypeVarTuple, ParamSpec # type: ignore[assignment]
18except ImportError:
19 warnings.warn(
20 "muutils.misc.func could not import Unpack and TypeVarTuple from typing or typing_extensions, typed_lambda may not work"
21 )
22 ParamSpec = TypeVar # type: ignore
23 Unpack = Any # type: ignore
24 TypeVarTuple = TypeVar # type: ignore
27from muutils.errormode import ErrorMode
# Emitted when this module is first imported.
warnings.warn("muutils.misc.func is experimental, use with caution")

# Return type of a function being decorated/wrapped by the helpers below.
ReturnType = TypeVar("ReturnType")
# Value type of the keyword argument inspected by `validate_kwarg` / `replace_kwarg`.
T_kwarg = TypeVar("T_kwarg")
# Input (caller-supplied) and output (function-expected) types of the
# converter passed to `process_kwarg`.
T_process_in = TypeVar("T_process_in")
T_process_out = TypeVar("T_process_out")

# Parameter specification of a decorated function as seen by callers.
# NOTE: `ParamSpec` may be the `TypeVar` fallback if the import above failed.
FuncParams = ParamSpec("FuncParams")
# Parameter specification of the original function before `process_kwarg` wraps it.
FuncParamsPreWrap = ParamSpec("FuncParamsPreWrap")
def process_kwarg(
    kwarg_name: str,
    processor: Callable[[T_process_in], T_process_out],
) -> Callable[
    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
]:
    """Decorator that converts one keyword argument before the call.

    Callers pass `kwarg_name` as a value of type `T_process_in`; the wrapped
    function receives `processor(value)`, of type `T_process_out`.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to convert.
     - `processor : Callable[[T_process_in], T_process_out]`
        Maps the caller-supplied value to the value the wrapped function expects.

    # Returns:
     - A decorator that turns a `Callable[FuncParamsPreWrap, ReturnType]`
       (expecting `kwarg_name` as `T_process_out`) into a
       `Callable[FuncParams, ReturnType]` (accepting it as `T_process_in`).

    # Notes:
     - Only a value passed *by keyword* is converted. If the argument is passed
       positionally or omitted, `processor` is not called.
    """

    def _decorate(
        func: Callable[FuncParamsPreWrap, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def _call_with_processed(*args: Any, **kwargs: Any) -> ReturnType:
            try:
                raw_value = kwargs[kwarg_name]
            except KeyError:
                # keyword not supplied: pass everything through untouched
                pass
            else:
                # T_process_in -> T_process_out
                kwargs[kwarg_name] = processor(raw_value)
            return func(*args, **kwargs)  # type: ignore[arg-type]

        return cast(Callable[FuncParams, ReturnType], _call_with_processed)

    return _decorate
# TYPING: error: Argument of type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" cannot be assigned to parameter of type "() -> ReturnType@process_kwarg"
# Type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" is not assignable to type "() -> ReturnType@process_kwarg"
# Extra parameter "kwarg_name"
# Extra parameter "validator" (reportArgumentType)
@process_kwarg("action", ErrorMode.from_any)  # pyright: ignore[reportArgumentType]
def validate_kwarg(
    kwarg_name: str,
    validator: Callable[[T_kwarg], bool],
    description: str | None = None,
    action: ErrorMode = ErrorMode.EXCEPT,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that validates a specific keyword argument.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to validate.
     - `validator : Callable[[Any], bool]`
        A callable that returns True if the keyword argument is valid.
     - `description : str | None`
        A message template if validation fails; formatted via `str.format`
        with `kwarg_name` and `value`.
     - `action : ErrorMode | str`
        What to do when validation fails. Warn mode (`"warn"` or the
        corresponding `ErrorMode`) emits a warning; every other mode raises.
        When passed by keyword, strings are converted with `ErrorMode.from_any`.
        (defaults to `ErrorMode.EXCEPT`)

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that validates the keyword argument.

    # Modifies:
     - If validation fails in warn mode, emits a `UserWarning`.
       Otherwise, raises a `ValueError`.
     - Only values passed *by keyword* are validated.

    # Usage:

    ```python
    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
    def my_func(x: int) -> int:
        return x

    assert my_func(x=1) == 1
    ```

    # Raises:
     - `ValueError` if validation fails and `action` is not warn mode.
    """
    # `process_kwarg` converts a keyword `action` into an `ErrorMode`, but a
    # positional `action` arrives unconverted (possibly a plain string).
    # Accept both spellings, and decide once here rather than on every call,
    # so the warn branch does not depend on whether `ErrorMode` compares equal
    # to plain strings.
    warn_only: bool = action == "warn" or action == ErrorMode.WARN

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
            if kwarg_name in kwargs:
                value: Any = kwargs[kwarg_name]
                if not validator(value):  # ty: ignore[invalid-argument-type]
                    msg: str = (
                        description.format(kwarg_name=kwarg_name, value=value)
                        if description
                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
                    )
                    if warn_only:
                        # stacklevel=2: attribute the warning to the caller of
                        # the decorated function, not to this wrapper
                        warnings.warn(msg, UserWarning, stacklevel=2)
                    else:
                        raise ValueError(msg)
            return func(*args, **kwargs)

        return cast(Callable[FuncParams, ReturnType], wrapper)

    return decorator
def replace_kwarg(
    kwarg_name: str,
    check: Callable[[T_kwarg], bool],
    replacement_value: T_kwarg,
    replace_if_missing: bool = False,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that replaces a keyword argument's value when a predicate matches.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to replace.
     - `check : Callable[[T_kwarg], bool]`
        A callable that returns True if the keyword argument should be replaced
        (e.g. `is_none` to replace `None` values).
     - `replacement_value : T_kwarg`
        The value to replace with.
     - `replace_if_missing : bool`
        If True, also inserts `replacement_value` when the caller did not pass
        `kwarg_name` by keyword at all (`check` is not called in that case).
        (defaults to `False`)

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that replaces the keyword argument value.

    # Modifies:
     - Sets `kwargs[kwarg_name] = replacement_value` if `check(kwargs[kwarg_name])`
       is truthy, or if the keyword is absent and `replace_if_missing` is True.
       Only keyword-passed arguments are inspected; positional ones are left alone.

    # Usage:

    ```python
    @replace_kwarg("x", lambda val: val is None, "default_string")
    def my_func(*, x: str | None = None) -> str:
        return x

    assert my_func(x=None) == "default_string"
    ```
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
            if kwarg_name in kwargs:
                # TODO: no way to type hint this, I think
                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
                    kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
            elif replace_if_missing:
                # keyword was not supplied at all
                kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
            return func(*args, **kwargs)

        # string form: the type expression is only for checkers and is not
        # evaluated at runtime
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator
def is_none(value: Any) -> bool:
    """Return True exactly when `value` is the `None` singleton (identity test)."""
    missing: bool = value is None
    return missing
def always_true(value: Any) -> bool:
    """Predicate that accepts every input.

    `value` is ignored; it exists so this fits a one-argument
    `Callable[[Any], bool]` slot.
    """
    del value  # intentionally unused
    return True
def always_false(value: Any) -> bool:
    """Predicate that rejects every input.

    `value` is ignored; it exists so this fits a one-argument
    `Callable[[Any], bool]` slot.
    """
    del value  # intentionally unused
    return False
def format_docstring(
    **fmt_kwargs: Any,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Build a decorator that fills `str.format` placeholders in a docstring.

    The decorated function's `__doc__` is replaced in place with
    `__doc__.format(**fmt_kwargs)`, and the same function object is returned.
    Functions without a docstring are returned unchanged. Because `str.format`
    is used, literal braces in the docstring must be doubled (`{{` / `}}`).
    """

    def _apply(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        template = func.__doc__
        if template is None:
            return func
        func.__doc__ = template.format(**fmt_kwargs)
        return func

    return _apply
# TODO: no way to make the type system understand this afaik
# Variadic type parameter for the positional arguments of the callable passed
# to `typed_lambda`. `TypeVarTuple` may be the `TypeVar` fallback if its import failed.
LambdaArgs = TypeVarTuple("LambdaArgs")
# Type of the `in_types` argument of `typed_lambda`: a tuple of types.
LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...])
def typed_lambda(  # pyright: ignore[reportUnknownParameterType]
    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
    in_types: LambdaArgsTypes,  # pyright: ignore[reportInvalidTypeVarUse]
    out_type: type[ReturnType],
) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
    """Wraps a lambda function with type hints.

    # Parameters:
     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
        The lambda function to wrap. It must expose `__code__`, as lambdas and
        plain `def` functions do.
     - `in_types : tuple[type, ...]`
        Tuple of input types, one per positional parameter of `fn`, in order.
     - `out_type : type[ReturnType]`
        The output type.

    # Returns:
     - `Callable[..., ReturnType]`
        A new function that forwards positional and keyword arguments to `fn`.
        Both its `__annotations__` and its `inspect.signature` carry `in_types`
        and `out_type`.

    # Usage:

    ```python
    add = typed_lambda(lambda x, y: x + y, (int, int), int)
    assert add(1, 2) == 3
    assert add(x=1, y=2) == 3
    ```

    # Raises:
     - `ValueError` if the number of input types doesn't match the lambda's parameters.
     - `AttributeError` if `fn` has no `__code__` attribute.
    """
    import inspect

    # it will just error here if fn.__code__ doesn't exist
    code: CodeType = fn.__code__  # type: ignore[unresolved-attribute]
    # `co_argcount` counts positional parameters (positional-only and
    # positional-or-keyword); keyword-only params and `*args`/`**kwargs` are excluded
    n_params: int = code.co_argcount

    if len(in_types) != n_params:
        raise ValueError(
            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
        )

    # the first `co_argcount` entries of `co_varnames` are those parameters' names
    param_names: tuple[str, ...] = code.co_varnames[:n_params]
    annotations: dict[str, type] = dict(zip(param_names, in_types))  # type: ignore[arg-type]
    annotations["return"] = out_type

    @functools.wraps(fn)
    def wrapped(*args: Unpack[LambdaArgs], **kwargs: Any) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
        # forward keywords too, so `add(x=1, y=2)` behaves like calling `fn`
        return fn(*args, **kwargs)

    wrapped.__annotations__ = annotations
    # `functools.wraps` sets `__wrapped__`, which would make `inspect.signature`
    # report `fn`'s unannotated signature; pin an annotated one explicitly
    fn_signature = inspect.signature(fn)
    wrapped.__signature__ = fn_signature.replace(  # type: ignore[attr-defined]
        parameters=[
            param.replace(annotation=annotations.get(param.name, param.annotation))
            for param in fn_signature.parameters.values()
        ],
        return_annotation=out_type,
    )
    return wrapped