Coverage for muutils/misc/func.py: 84%
86 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1from __future__ import annotations
2import functools
3import sys
4from types import CodeType
5import warnings
6from typing import Any, Callable, Tuple, cast, TypeVar
8try:
9 if sys.version_info >= (3, 11):
10 # 3.11+
11 from typing import Unpack, TypeVarTuple, ParamSpec
12 else:
13 # 3.9+
14 from typing_extensions import Unpack, TypeVarTuple, ParamSpec # type: ignore[assignment]
15except ImportError:
16 warnings.warn(
17 "muutils.misc.func could not import Unpack and TypeVarTuple from typing or typing_extensions, typed_lambda may not work"
18 )
19 ParamSpec = TypeVar # type: ignore
20 Unpack = Any # type: ignore
21 TypeVarTuple = TypeVar # type: ignore
24from muutils.errormode import ErrorMode
26warnings.warn("muutils.misc.func is experimental, use with caution")
28ReturnType = TypeVar("ReturnType")
29T_kwarg = TypeVar("T_kwarg")
30T_process_in = TypeVar("T_process_in")
31T_process_out = TypeVar("T_process_out")
33FuncParams = ParamSpec("FuncParams")
34FuncParamsPreWrap = ParamSpec("FuncParamsPreWrap")
def process_kwarg(
    kwarg_name: str,
    processor: Callable[[T_process_in], T_process_out],
) -> Callable[
    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
]:
    """Build a decorator that converts one keyword argument before the call.

    The caller supplies `kwarg_name` with a value of type `T_process_in`. The
    wrapper passes that value through `processor` and forwards the resulting
    `T_process_out` value to the decorated function. Only keyword usage is
    intercepted. If the argument is given positionally, or not at all, it is
    forwarded unchanged.

    # Parameters:
     - `kwarg_name : str`
        name of the keyword argument to convert
     - `processor : Callable[[T_process_in], T_process_out]`
        conversion applied to the caller-supplied value

    # Returns:
     - `Callable[[Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]]`
        a decorator whose wrapper keeps the decorated function's metadata
        (via `functools.wraps`) and returns whatever that function returns
    """

    def decorator(
        func: Callable[FuncParamsPreWrap, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            # rebuild the keyword mapping, converting only the targeted entry
            # (T_process_in -> T_process_out); key order is preserved
            forwarded: dict[str, Any] = {
                key: processor(val) if key == kwarg_name else val
                for key, val in kwargs.items()
            }
            return func(*args, **forwarded)  # type: ignore[arg-type]

        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator
@process_kwarg("action", ErrorMode.from_any)
def validate_kwarg(
    kwarg_name: str,
    validator: Callable[[T_kwarg], bool],
    description: str | None = None,
    action: ErrorMode = ErrorMode.EXCEPT,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that validates a specific keyword argument.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to validate.
     - `validator : Callable[[Any], bool]`
        A callable that returns True if the keyword argument is valid.
     - `description : str | None`
        A message template used if validation fails. It is rendered with
        `str.format(kwarg_name=..., value=...)`. If falsy, a default message
        is used. (defaults to `None`)
     - `action : ErrorMode | str`
        What to do when validation fails. `ErrorMode.WARN`, or a string that
        `ErrorMode.from_any` maps to it (e.g. `"warn"`), emits a
        `UserWarning`. Any other mode raises a `ValueError`.
        (defaults to `ErrorMode.EXCEPT`)

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that validates the keyword argument.

    # Modifies:
     - If validation fails and `action` is `ErrorMode.WARN`, emits a warning
       and still calls the wrapped function. Otherwise, raises a `ValueError`.
     - Only keyword usage is checked; a value passed positionally is not
       validated.

    # Usage:

    ```python
    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
    def my_func(x: int) -> int:
        return x

    assert my_func(x=1) == 1
    ```

    # Raises:
     - `ValueError` if validation fails and `action` is not `ErrorMode.WARN`.
    """
    # `process_kwarg` only converts `action` when it is passed by keyword, so
    # normalize here too in case a string was passed positionally.
    mode: ErrorMode = (
        action if isinstance(action, ErrorMode) else ErrorMode.from_any(action)
    )
    # compare enum members, not strings: `mode` is an `ErrorMode` at this point
    warn_on_failure: bool = mode is ErrorMode.WARN

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                value: Any = kwargs[kwarg_name]
                if not validator(value):
                    msg: str = (
                        description.format(kwarg_name=kwarg_name, value=value)
                        if description
                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
                    )
                    if warn_on_failure:
                        warnings.warn(msg, UserWarning)
                    else:
                        raise ValueError(msg)
            return func(*args, **kwargs)

        return cast(Callable[FuncParams, ReturnType], wrapper)

    return decorator
def replace_kwarg(
    kwarg_name: str,
    check: Callable[[T_kwarg], bool],
    replacement_value: T_kwarg,
    replace_if_missing: bool = False,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that replaces a keyword argument's value when `check` approves it.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to replace.
     - `check : Callable[[T_kwarg], bool]`
        Called with the caller-supplied value. If it returns True, the value
        is swapped for `replacement_value`.
     - `replacement_value : T_kwarg`
        The value to substitute.
     - `replace_if_missing : bool`
        If True, also inject `replacement_value` when the caller did not pass
        `kwarg_name` as a keyword at all. (defaults to `False`)

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that replaces the keyword argument value.

    # Modifies:
     - Sets `kwargs[kwarg_name] = replacement_value` in the keyword arguments
       forwarded to the wrapped function. This happens when `check` returns
       True for the supplied value, or when the keyword is absent and
       `replace_if_missing` is set.

    # Usage:

    ```python
    @replace_kwarg("x", lambda val: val is None, "default_string")
    def my_func(*, x: str | None = None) -> str:
        return x

    assert my_func(x=None) == "default_string"
    ```

    # Notes:
     - Only keyword usage is inspected. If the parameter can be passed
       positionally and the caller does so while `replace_if_missing=True`,
       the injected keyword collides with the positional value and the call
       fails with `TypeError` (multiple values for the argument).
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                # TODO: no way to type hint this, I think
                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
                    kwargs[kwarg_name] = replacement_value
            elif replace_if_missing:
                # keyword absent: inject the replacement
                kwargs[kwarg_name] = replacement_value
            return func(*args, **kwargs)

        # string form of the cast: a no-op at runtime, same meaning to type checkers
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator
def is_none(value: Any) -> bool:
    """Predicate that is True exactly when `value` is the `None` singleton.

    Suitable as the `check` argument of `replace_kwarg`.
    """
    result: bool = value is None
    return result
def always_true(value: Any) -> bool:
    """Predicate that accepts everything; `value` is ignored.

    As a `check` for `replace_kwarg`, it makes the replacement unconditional.
    """
    return True
def always_false(value: Any) -> bool:
    """Predicate that rejects everything; `value` is ignored.

    As a `check` for `replace_kwarg`, it disables replacement of supplied values.
    """
    return False
def format_docstring(
    **fmt_kwargs: Any,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator factory that fills `{placeholders}` in a function's docstring.

    The docstring is rendered in place with `str.format(**fmt_kwargs)` when
    the decorator is applied, and the same function object is returned.
    Functions whose `__doc__` is `None` are returned untouched. Literal braces
    in the docstring must be doubled (`{{` / `}}`), as in any `str.format`
    template.
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        template: str | None = func.__doc__
        if template is None:
            # nothing to format (no docstring, e.g. under `python -OO`)
            return func
        func.__doc__ = template.format(**fmt_kwargs)
        return func

    return decorator
223# TODO: no way to make the type system understand this afaik
224LambdaArgs = TypeVarTuple("LambdaArgs")
225LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...])
def typed_lambda(
    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
    in_types: LambdaArgsTypes,
    out_type: type[ReturnType],
) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
    """Wraps a lambda function with type hints.

    # Parameters:
     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
        The lambda function to wrap. It must be a plain Python function,
        since its `__code__` attribute is read.
     - `in_types : tuple[type, ...]`
        One type per positional parameter of `fn`, in order.
     - `out_type : type[ReturnType]`
        The output type.

    # Returns:
     - `Callable[..., ReturnType]`
        A new function that forwards positional and keyword arguments to
        `fn`. Both its `__annotations__` and its `inspect.signature` carry
        the given types.

    # Usage:

    ```python
    add = typed_lambda(lambda x, y: x + y, (int, int), int)
    assert add(1, 2) == 3
    ```

    # Raises:
     - `ValueError` if the number of input types doesn't match the lambda's
       positional parameters.
    """
    import inspect  # only needed here, to build the annotated signature

    code: CodeType = fn.__code__
    # co_argcount counts positional(-only) and positional-or-keyword parameters;
    # *args and keyword-only parameters are not included
    n_params: int = code.co_argcount

    if len(in_types) != n_params:
        raise ValueError(
            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
        )

    # the first co_argcount entries of co_varnames are those parameters' names
    param_names: tuple[str, ...] = code.co_varnames[:n_params]
    param_types: dict[str, type] = dict(zip(param_names, in_types))  # type: ignore[arg-type]
    annotations: dict[str, type] = {**param_types, "return": out_type}

    @functools.wraps(fn)
    def wrapped(*args: Unpack[LambdaArgs], **kwargs: Any) -> ReturnType:
        return fn(*args, **kwargs)

    wrapped.__annotations__ = annotations
    # functools.wraps sets `__wrapped__`, and inspect.signature follows it back
    # to the un-annotated lambda unless an explicit `__signature__` is present.
    base_sig = inspect.signature(fn)
    wrapped.__signature__ = base_sig.replace(  # type: ignore[attr-defined]
        parameters=[
            p.replace(annotation=param_types[p.name]) if p.name in param_types else p
            for p in base_sig.parameters.values()
        ],
        return_annotation=out_type,
    )
    return wrapped