Coverage for muutils/misc/func.py: 84%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1from __future__ import annotations 

2import functools 

3import sys 

4from types import CodeType 

5import warnings 

6from typing import Any, Callable, Tuple, cast, TypeVar 

7 

8try: 

9 if sys.version_info >= (3, 11): 

10 # 3.11+ 

11 from typing import Unpack, TypeVarTuple, ParamSpec 

12 else: 

13 # 3.9+ 

14 from typing_extensions import Unpack, TypeVarTuple, ParamSpec # type: ignore[assignment] 

15except ImportError: 

16 warnings.warn( 

17 "muutils.misc.func could not import Unpack and TypeVarTuple from typing or typing_extensions, typed_lambda may not work" 

18 ) 

19 ParamSpec = TypeVar # type: ignore 

20 Unpack = Any # type: ignore 

21 TypeVarTuple = TypeVar # type: ignore 

22 

23 

24from muutils.errormode import ErrorMode 

25 

26warnings.warn("muutils.misc.func is experimental, use with caution") 

27 

28ReturnType = TypeVar("ReturnType") 

29T_kwarg = TypeVar("T_kwarg") 

30T_process_in = TypeVar("T_process_in") 

31T_process_out = TypeVar("T_process_out") 

32 

33FuncParams = ParamSpec("FuncParams") 

34FuncParamsPreWrap = ParamSpec("FuncParamsPreWrap") 

35 

36 

def process_kwarg(
    kwarg_name: str,
    processor: Callable[[T_process_in], T_process_out],
) -> Callable[
    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
]:
    """Build a decorator that converts one keyword argument before the call.

    The decorated function expects `kwarg_name` as a `T_process_out`. Callers
    may instead pass a `T_process_in`, which is run through `processor` first.
    Only keyword usage is intercepted. A value passed positionally, or not
    passed at all, reaches the function unchanged.

    # Parameters:
     - `kwarg_name : str`
        name of the keyword argument to convert
     - `processor : Callable[[T_process_in], T_process_out]`
        conversion applied to the caller-supplied value

    # Returns:
     - `Callable[[Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]]`
        decorator whose wrapper (carrying the original's metadata via
        `functools.wraps`) applies `processor` and then delegates to the
        original function
    """

    def _decorate(
        func: Callable[FuncParamsPreWrap, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def _call(*args: Any, **kwargs: Any) -> ReturnType:
            if kwarg_name in kwargs:
                # rebuild the mapping with the converted value; an existing key
                # keeps its position, so the call is the same as mutating in place
                kwargs = {**kwargs, kwarg_name: processor(kwargs[kwarg_name])}
            return func(*args, **kwargs)  # type: ignore[arg-type]

        # string form of `cast`: same runtime result, no typing object is built
        return cast("Callable[FuncParams, ReturnType]", _call)

    return _decorate

75 

76 

@process_kwarg("action", ErrorMode.from_any)
def validate_kwarg(
    kwarg_name: str,
    validator: Callable[[T_kwarg], bool],
    description: str | None = None,
    action: ErrorMode = ErrorMode.EXCEPT,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that validates a specific keyword argument.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to validate. Validation only runs
        when the argument is passed by keyword. Positional or omitted values
        are not checked.
     - `validator : Callable[[T_kwarg], bool]`
        A callable that returns True if the keyword argument is valid.
     - `description : str | None`
        A message template used if validation fails. It is formatted with the
        placeholders `{kwarg_name}` and `{value}`. If `None` or empty, a
        default message is used.
        (defaults to `None`)
     - `action : ErrorMode | str`
        How to handle a failed validation. Pass an `ErrorMode`, or a string
        accepted by `ErrorMode.from_any` (e.g. `"warn"`). `ErrorMode.WARN`
        emits a `UserWarning`; every other mode raises a `ValueError`.
        (defaults to `ErrorMode.EXCEPT`)

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that validates the keyword argument.

    # Modifies:
     - If validation fails and the mode is `ErrorMode.WARN`, emits a warning.
       Otherwise, raises a ValueError.

    # Usage:

    ```python
    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
    def my_func(x: int) -> int:
        return x

    assert my_func(x=1) == 1
    ```

    # Raises:
     - `ValueError` (at call time) if validation fails and the mode is not
       `ErrorMode.WARN`.
     - whatever `ErrorMode.from_any` raises (at decoration time) for an
       unrecognized `action`.
    """
    # `process_kwarg` only converts `action` when it is passed by keyword, so a
    # positional string arrives here unconverted. Normalize once, at decoration
    # time. Compare by identity against the enum member: the previous
    # `action == "warn"` compared an ErrorMode to a str, so the warn branch was
    # never taken for a keyword `action`.
    mode: ErrorMode = ErrorMode.from_any(action)

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                value: Any = kwargs[kwarg_name]
                if not validator(value):
                    msg: str = (
                        description.format(kwarg_name=kwarg_name, value=value)
                        if description
                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
                    )
                    if mode is ErrorMode.WARN:
                        warnings.warn(msg, UserWarning)
                    else:
                        raise ValueError(msg)
            return func(*args, **kwargs)

        # string form of `cast`: same runtime result, no typing object is built
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator

140 

141 

def replace_kwarg(
    kwarg_name: str,
    check: Callable[[T_kwarg], bool],
    replacement_value: T_kwarg,
    replace_if_missing: bool = False,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that replaces a specific keyword argument value when `check` approves.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to replace. Only values passed by
        keyword are inspected. Positional arguments are left untouched.
     - `check : Callable[[T_kwarg], bool]`
        Called with the caller's value. If it returns True, the value is replaced.
     - `replacement_value : T_kwarg`
        The value to substitute.
     - `replace_if_missing : bool`
        If True, also inject `replacement_value` when `kwarg_name` is not passed
        by keyword at all. `check` is not called in that case. Note: if the
        parameter was supplied positionally, this injection makes the call fail
        with a `TypeError` (multiple values for the argument).
        (defaults to `False`)

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that replaces the keyword argument value.

    # Modifies:
     - Sets `kwargs[kwarg_name] = replacement_value` when
       `check(kwargs[kwarg_name])` is True, or when the keyword is missing and
       `replace_if_missing` is set.

    # Usage:

    ```python
    @replace_kwarg("x", lambda val: val is None, "default_string")
    def my_func(*, x: str | None = None) -> str:
        return x

    assert my_func(x=None) == "default_string"
    assert my_func(x="other") == "other"
    ```
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                # TODO: no way to type hint this, I think
                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
                    kwargs[kwarg_name] = replacement_value
            elif replace_if_missing:
                # this branch already implies the keyword is absent
                kwargs[kwarg_name] = replacement_value
            return func(*args, **kwargs)

        # string form of `cast`: same runtime result, no typing object is built
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator

194 

195 

def is_none(value: Any) -> bool:
    """Predicate: True exactly when `value` is the `None` singleton (identity test, not `==`)."""
    result: bool = value is None
    return result

198 

199 

def always_true(value: Any) -> bool:
    """Predicate that accepts everything: ignores `value` and returns True."""
    del value  # intentionally unused; kept so the signature matches other predicates
    return True

202 

203 

def always_false(value: Any) -> bool:
    """Predicate that rejects everything: ignores `value` and returns False."""
    del value  # intentionally unused; kept so the signature matches other predicates
    return False

206 

207 

def format_docstring(
    **fmt_kwargs: Any,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that fills `str.format` placeholders in a function's docstring.

    The function object itself is returned (no wrapper). Its `__doc__` is
    replaced in place with `__doc__.format(**fmt_kwargs)`. Functions without a
    docstring are returned untouched.
    """

    def _apply(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        template: str | None = func.__doc__
        if template is None:
            return func
        func.__doc__ = template.format(**fmt_kwargs)
        return func

    return _apply

221 

222 

223# TODO: no way to make the type system understand this afaik 

224LambdaArgs = TypeVarTuple("LambdaArgs") 

225LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...]) 

226 

227 

def typed_lambda(
    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
    in_types: LambdaArgsTypes,
    out_type: type[ReturnType],
) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
    """Wraps a lambda function with type hints.

    The returned function forwards all positional and keyword arguments to
    `fn`. It carries both an `__annotations__` dict and an annotated
    `__signature__` built from `in_types` and `out_type`. As a result,
    `typing.get_type_hints` and `inspect.signature` both report the given types.

    # Parameters:
     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
        The lambda function to wrap. It must expose `__code__`, as plain Python
        functions and lambdas do.
     - `in_types : tuple[type, ...]`
        Tuple of input types, one per positional parameter of `fn`, in order.
     - `out_type : type[ReturnType]`
        The output type.

    # Returns:
     - `Callable[..., ReturnType]`
        A new function with annotations matching the given signature.

    # Usage:

    ```python
    add = typed_lambda(lambda x, y: x + y, (int, int), int)
    assert add(1, 2) == 3
    assert add(x=1, y=2) == 3
    ```

    # Raises:
     - `ValueError` if the number of input types doesn't match the lambda's
       positional parameters.
    """
    import inspect

    code: CodeType = fn.__code__
    # co_argcount counts positional parameters only (not *args / keyword-only)
    n_params: int = code.co_argcount

    if len(in_types) != n_params:
        raise ValueError(
            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
        )

    # co_varnames starts with the positional parameter names, in order
    param_names: tuple[str, ...] = code.co_varnames[:n_params]
    annotations: dict[str, type] = dict(zip(param_names, in_types))
    annotations["return"] = out_type

    @functools.wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> ReturnType:
        # forward keywords too, so the parameter names recorded in the
        # annotations can actually be used at the call site
        return fn(*args, **kwargs)

    # assign a new dict: `functools.wraps` shares fn's own annotations dict
    wrapped.__annotations__ = annotations

    # `inspect.signature` follows `__wrapped__` (set by `functools.wraps`) back
    # to the unannotated lambda unless `__signature__` is present, so attach an
    # annotated copy of fn's signature
    base_sig: inspect.Signature = inspect.signature(fn)
    wrapped.__signature__ = base_sig.replace(  # type: ignore[attr-defined]
        parameters=[
            param.replace(annotation=annotations[param.name])
            if param.name in annotations
            else param
            for param in base_sig.parameters.values()
        ],
        return_annotation=out_type,
    )
    return wrapped