Coverage for muutils / misc / func.py: 84%

86 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:25 -0700

1from __future__ import annotations 

2import functools 

3import sys 

4from types import CodeType 

5import warnings 

6from typing import Any, Callable, Tuple, cast, TypeVar 

7 

8# TODO: we do a lot of type weirdness here that basedpyright doesn't like 

9# pyright: reportInvalidTypeForm=false 

10 

try:
    if sys.version_info < (3, 11):
        # before 3.11 these typing constructs are taken from the backport package
        from typing_extensions import Unpack, TypeVarTuple, ParamSpec  # type: ignore[assignment]
    else:
        # 3.11+: everything is available from the standard `typing` module
        from typing import Unpack, TypeVarTuple, ParamSpec
except ImportError:
    warnings.warn(
        "muutils.misc.func could not import Unpack and TypeVarTuple from typing or typing_extensions, typed_lambda may not work"
    )
    # degrade gracefully: plain `TypeVar` / `Any` stand in for the missing
    # constructs so the module-level declarations below can still be created
    TypeVarTuple = TypeVar  # type: ignore
    Unpack = Any  # type: ignore
    ParamSpec = TypeVar  # type: ignore

25 

26 

27from muutils.errormode import ErrorMode 

28 

29warnings.warn("muutils.misc.func is experimental, use with caution") 

30 

31ReturnType = TypeVar("ReturnType") 

32T_kwarg = TypeVar("T_kwarg") 

33T_process_in = TypeVar("T_process_in") 

34T_process_out = TypeVar("T_process_out") 

35 

36FuncParams = ParamSpec("FuncParams") 

37FuncParamsPreWrap = ParamSpec("FuncParamsPreWrap") 

38 

39 

def process_kwarg(
    kwarg_name: str,
    processor: Callable[[T_process_in], T_process_out],
) -> Callable[
    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
]:
    """Decorator that applies a processor to a keyword argument.

    The underlying function is expected to have a keyword argument
    (with name `kwarg_name`) of type `T_process_out`, but the caller provides
    a value of type `T_process_in` that is converted via `processor`.

    Only a value passed *by keyword* is processed. If `kwarg_name` is passed
    positionally, or omitted (so the function's own default applies),
    `processor` is not called and the value reaches the function unchanged.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to process.
     - `processor : Callable[[T_process_in], T_process_out]`
        A callable that converts the caller's value (`T_process_in`) into the
        type expected by the function (`T_process_out`).

    # Returns:
     - A decorator that converts a function of type
       `Callable[FuncParamsPreWrap, ReturnType]` (expecting `kwarg_name` of type `T_process_out`)
       into one of type `Callable[FuncParams, ReturnType]` (accepting `kwarg_name` of type `T_process_in`).

    # Usage:

    ```python
    @process_kwarg("n", int)
    def f(*, n: int = 0) -> int:
        return n

    assert f(n="3") == 3
    ```
    """

    def decorator(
        func: Callable[FuncParamsPreWrap, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            if kwarg_name in kwargs:
                # Convert the caller's value (of type T_process_in) to T_process_out
                kwargs[kwarg_name] = processor(kwargs[kwarg_name])
            return func(*args, **kwargs)  # type: ignore[arg-type]

        # string form of the cast: identical for type checkers, but never
        # evaluated at runtime, so no `Callable[ParamSpec, ...]` subscription on
        # every decoration (which can fail if `ParamSpec` fell back to `TypeVar`)
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator

78 

79 

# TYPING: error: Argument of type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" cannot be assigned to parameter of type "() -> ReturnType@process_kwarg"
# Type "(kwarg_name: str, validator: (T_kwarg@validate_kwarg) -> bool, description: str | None = None, action: ErrorMode = ErrorMode.EXCEPT) -> (((() -> ReturnType@validate_kwarg)) -> (() -> ReturnType@validate_kwarg))" is not assignable to type "() -> ReturnType@process_kwarg"
# Extra parameter "kwarg_name"
# Extra parameter "validator" (reportArgumentType)
@process_kwarg("action", ErrorMode.from_any)  # pyright: ignore[reportArgumentType]
def validate_kwarg(
    kwarg_name: str,
    validator: Callable[[T_kwarg], bool],
    description: str | None = None,
    action: ErrorMode = ErrorMode.EXCEPT,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that validates a specific keyword argument.

    Only a value passed *by keyword* is validated; if `kwarg_name` is passed
    positionally or omitted, `validator` is not called.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to validate.
     - `validator : Callable[[T_kwarg], bool]`
        A callable that returns True if the keyword argument is valid.
     - `description : str | None`
        A message template used if validation fails; it is filled via
        `str.format` with the placeholders `{kwarg_name}` and `{value}`.
        If empty or None, a generic message is used.
        (defaults to `None`)
     - `action : ErrorMode | str`
        What to do on failure. When passed by keyword it is converted with
        `ErrorMode.from_any`. `ErrorMode.WARN` (or the string `"warn"`) emits
        a `UserWarning`; any other value raises `ValueError`.
        (defaults to `ErrorMode.EXCEPT`)

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that validates the keyword argument.

    # Modifies:
     - If validation fails and `action` is WARN, emits a `UserWarning` and
       still calls the function. Otherwise, raises a `ValueError`.

    # Usage:

    ```python
    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
    def my_func(x: int) -> int:
        return x

    assert my_func(x=1) == 1

    @validate_kwarg("x", lambda val: val > 0, action="warn")
    def lenient(x: int) -> int:
        return x

    lenient(x=-1)  # warns instead of raising
    ```

    # Raises:
     - `ValueError` (when the decorated function is called) if validation fails
       and `action` is not WARN.
    """
    # A keyword `action` has been turned into an `ErrorMode` member by
    # `process_kwarg`, but a positional one arrives unconverted (possibly a
    # plain string), so both spellings must be recognised.
    # NOTE(review): comparing the member against "warn" alone is not enough if
    # `ErrorMode` is a plain `Enum` (members do not equal their values); the
    # identity check covers that case, the string check covers the other.
    warn_on_failure: bool = action is ErrorMode.WARN or action == "warn"

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
            if kwarg_name in kwargs:
                value: Any = kwargs[kwarg_name]
                if not validator(value):  # ty: ignore[invalid-argument-type]
                    msg: str = (
                        description.format(kwarg_name=kwarg_name, value=value)
                        if description
                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
                    )
                    if warn_on_failure:
                        warnings.warn(msg, UserWarning)
                    else:
                        raise ValueError(msg)
            return func(*args, **kwargs)

        # string form of the cast: never evaluated at runtime
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator

147 

148 

def replace_kwarg(
    kwarg_name: str,
    check: Callable[[T_kwarg], bool],
    replacement_value: T_kwarg,
    replace_if_missing: bool = False,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that replaces a keyword argument's value when a predicate accepts it.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to replace.
     - `check : Callable[[T_kwarg], bool]`
        Predicate called with the value passed for `kwarg_name`; if it returns
        a truthy value, that value is replaced.
     - `replacement_value : T_kwarg`
        The value to substitute.
     - `replace_if_missing : bool`
        If True, `kwarg_name` is also set to `replacement_value` when the caller
        did not pass it by keyword (`check` is not called in that case).
        (defaults to `False`)

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that replaces the keyword argument value.

    # Modifies:
     - Sets `kwargs[kwarg_name] = replacement_value` when `check(kwargs[kwarg_name])`
       is truthy, or when the keyword is absent and `replace_if_missing` is True.

    Only keyword-passed values are inspected. A positionally supplied value is
    not seen here; with `replace_if_missing=True` the parameter would then be
    supplied twice, and the call fails with `TypeError`.

    # Usage:

    ```python
    @replace_kwarg("x", is_none, "default_string")
    def my_func(*, x: str | None = None) -> str | None:
        return x

    assert my_func(x=None) == "default_string"
    assert my_func(x="hi") == "hi"
    ```
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
            if kwarg_name in kwargs:
                # TODO: no way to type hint this, I think
                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
                    kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
            elif replace_if_missing:
                kwargs[kwarg_name] = replacement_value  # ty: ignore[invalid-assignment]
            return func(*args, **kwargs)

        # string form of the cast: never evaluated at runtime
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator

201 

202 

def is_none(value: Any) -> bool:
    """Predicate: True exactly when `value` is `None`.

    Uses an identity check, so objects that merely compare equal to `None`
    do not match. Suitable as the `check` argument of `replace_kwarg`.
    """
    matches: bool = value is None
    return matches

205 

206 

def always_true(value: Any) -> bool:
    """Predicate that accepts every input; `value` is ignored."""
    del value  # unused by design: the signature matches the other predicates
    return True

209 

210 

def always_false(value: Any) -> bool:
    """Predicate that rejects every input; `value` is ignored."""
    del value  # unused by design: the signature matches the other predicates
    return False

213 

214 

def format_docstring(
    **fmt_kwargs: Any,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that formats a function's docstring with the provided keyword arguments.

    The docstring is rewritten in place via `str.format(**fmt_kwargs)` and the
    *same* function object is returned (no wrapper is created). A function
    without a docstring is returned untouched.

    # Usage:

    ```python
    @format_docstring(unit="seconds")
    def wait(t: float) -> None:
        '''Sleep for `t` {unit}.'''
    ```
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        template: str | None = func.__doc__
        if template is None:
            return func
        # as with any `str.format` template, literal braces must be doubled
        func.__doc__ = template.format(**fmt_kwargs)
        return func

    return decorator

228 

229 

# type-level helpers for `typed_lambda`: `LambdaArgs` stands for the lambda's
# parameter types, `LambdaArgsTypes` for the tuple of runtime types passed
# alongside it
# TODO: no way to make the type system understand this afaik
LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...])
LambdaArgs = TypeVarTuple("LambdaArgs")

233 

234 

def typed_lambda(  # pyright: ignore[reportUnknownParameterType]
    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
    in_types: LambdaArgsTypes,  # pyright: ignore[reportInvalidTypeVarUse]
    out_type: type[ReturnType],
) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
    """Wraps a lambda function with type hints.

    # Parameters:
     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
        The lambda function to wrap.
     - `in_types : tuple[type, ...]`
        Tuple of input types, one per positional parameter of `fn`, in order.
     - `out_type : type[ReturnType]`
        The output type.

    # Returns:
     - `Callable[..., ReturnType]`
        A new function that forwards positional and keyword arguments to `fn`.
        Both its `__annotations__` and `inspect.signature` carry the given types.

    # Usage:

    ```python
    add = typed_lambda(lambda x, y: x + y, (int, int), int)
    assert add(1, 2) == 3
    assert add(x=1, y=2) == 3
    ```

    # Raises:
     - `ValueError` if the number of input types doesn't match the lambda's
       positional parameters.
     - `AttributeError` if `fn` has no `__code__` (i.e. is not a Python function).
    """
    import inspect  # local: only needed to build the advertised signature

    # it will just error here if fn.__code__ doesn't exist
    code: CodeType = fn.__code__  # type: ignore[unresolved-attribute]
    # `co_argcount` counts positional(-or-keyword) parameters only; keyword-only
    # parameters and `*args` / `**kwargs` are not included
    n_params: int = code.co_argcount

    if len(in_types) != n_params:
        raise ValueError(
            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
        )

    # `co_varnames` starts with the argument names, positional ones first
    param_names: tuple[str, ...] = code.co_varnames[:n_params]
    annotations: dict[str, type] = dict(zip(param_names, in_types))  # type: ignore[arg-type]
    annotations["return"] = out_type

    @functools.wraps(fn)
    def wrapped(*args: Unpack[LambdaArgs], **kwargs: Any) -> ReturnType:  # pyright: ignore[reportUnknownParameterType]
        # forward keywords too, so named calls work exactly as they do on `fn`
        return fn(*args, **kwargs)  # type: ignore[call-arg]

    wrapped.__annotations__ = annotations
    # `functools.wraps` sets `__wrapped__`, which makes `inspect.signature`
    # report `fn`'s unannotated signature; an explicit `__signature__` stops
    # that unwrapping and exposes the annotated one instead
    fn_signature = inspect.signature(fn)
    wrapped.__signature__ = fn_signature.replace(  # type: ignore[attr-defined]
        parameters=[
            param.replace(annotation=annotations.get(param.name, param.annotation))
            for param in fn_signature.parameters.values()
        ],
        return_annotation=out_type,
    )
    return wrapped