Coverage for muutils / validate_type.py: 73%
97 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-18 21:32 -0600
1"""experimental utility for validating types in python, see `validate_type`"""
3from __future__ import annotations
5from inspect import signature, unwrap
6import types
7import typing
8import functools
9from typing import Any
# Names of the generic-alias classes to look up on `typing`.
# Not every name exists on every python version (this is also for python <3.10 compatibility).
_GenericAliasTypeNames: typing.List[str] = [
    "GenericAlias",
    "_GenericAlias",
    "_UnionGenericAlias",
    "_BaseGenericAlias",
]

# one entry per name above; a name missing from `typing` maps to `None`
_GenericAliasTypesList: list[Any] = list(
    map(lambda alias_name: getattr(typing, alias_name, None), _GenericAliasTypeNames)
)

# only the alias classes that actually exist, in the same order --
# suitable as the second argument of `isinstance`
GenericAliasTypes: tuple[Any, ...] = tuple(
    filter(lambda alias_cls: alias_cls is not None, _GenericAliasTypesList)
)
class IncorrectTypeException(TypeError):
    """Raised by `validate_type(..., do_except=True)` when a value does not match the expected type."""
class TypeHintNotImplementedError(NotImplementedError):
    """Raised by `validate_type` when it is given a type hint it does not know how to check."""
class InvalidGenericAliasError(TypeError):
    """Raised by `validate_type` when a generic alias has the wrong number of type arguments (e.g. `list[int, str]`)."""
40def _return_validation_except(
41 return_val: bool, value: typing.Any, expected_type: typing.Any
42) -> bool:
43 if return_val:
44 return True
45 else:
46 raise IncorrectTypeException(
47 f"Expected {expected_type = } for {value = }",
48 f"{type(value) = }",
49 f"{type(value).__mro__ = }",
50 f"{typing.get_origin(expected_type) = }",
51 f"{typing.get_args(expected_type) = }",
52 "\ndo --tb=long in pytest to see full trace",
53 )
54 return False
57def _return_validation_bool(return_val: bool) -> bool:
58 return return_val
def validate_type(
    value: typing.Any, expected_type: typing.Any, do_except: bool = False
) -> bool:
    """Validate that a `value` is of the `expected_type`

    # Parameters
     - `value`: the value to check the type of
     - `expected_type`: the type to check against. Not all types are supported
     - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`)
       (default: `False`)

    # Returns
     - `bool`: `True` if the value is of the expected type, `False` otherwise.

    # Supported type hints
     - `typing.Any` (always `True`) and `None` (value must be `None`)
     - plain classes, checked with `isinstance`
     - `typing.Union[...]` / `X | Y`: the value must match at least one member
     - `typing.Literal[...]`: membership test, i.e. compared with `==`
     - `list[T]`, `set[T]`, `dict[K, V]` (and their `typing` aliases), with or without args
     - `tuple[T1, T2, ...]` (fixed length, checked element-wise) and
       `tuple[T, ...]` (variable length, every element must match `T`)
     - `type[T]`: the value must be a class with `T` in its MRO; `type[Any]` accepts any class

    Nested hints are validated recursively. Nested checks never raise
    `IncorrectTypeException`; only the top-level result does, when `do_except` is `True`.

    # Raises
     - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True`
     - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented
     - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid

    use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard
    """
    if expected_type is typing.Any:
        return True

    # set up the return function depending on `do_except`
    _return_func: typing.Callable[[bool], bool] = (
        # functools.partial doesn't hint the function signature
        functools.partial(  # type: ignore[assignment]
            _return_validation_except, value=value, expected_type=expected_type
        )
        if do_except
        else _return_validation_bool
    )

    # handle None type (used in type hints like tuple[int, None])
    if expected_type is None:
        return _return_func(value is None)

    # base type without args
    if isinstance(expected_type, type):
        try:
            # if you use args on a type like `dict[str, int]`, this will fail with a
            # `TypeError` (such parameterized aliases can pass the `isinstance(..., type)`
            # check above on some python versions); fall through to the generic handling
            return _return_func(isinstance(value, expected_type))
        except TypeError as e:
            # `IncorrectTypeException` is itself a `TypeError` -- never swallow it
            if isinstance(e, IncorrectTypeException):
                raise e

    origin: typing.Any = typing.get_origin(expected_type)
    args: tuple[Any, ...] = typing.get_args(expected_type)

    # `types.UnionType` (the `X | Y` syntax) is looked up dynamically because it
    # does not exist on older pythons; `typing.Union` covers those
    UnionType = getattr(types, "UnionType", None)
    if origin is typing.Union or (UnionType is not None and origin is UnionType):
        return _return_func(any(validate_type(value, arg) for arg in args))

    # generic alias, more complicated
    item_type: type
    if isinstance(expected_type, GenericAliasTypes):
        if origin is typing.Literal:
            # membership uses `==`, so e.g. `True` matches `Literal[1]`
            return _return_func(value in args)

        if origin is list:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, list))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is list
            if not isinstance(value, list):
                return _return_func(False)
            # check all items in list are of the correct type
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))

        if origin is dict:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, dict))
            # incorrect number of args
            if len(args) != 2:
                raise InvalidGenericAliasError(
                    f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is dict
            if not isinstance(value, dict):
                return _return_func(False)
            # check all keys and values in dict are of the correct type
            key_type: type = args[0]
            value_type: type = args[1]
            return _return_func(
                all(
                    validate_type(key, key_type) and validate_type(val, value_type)
                    for key, val in value.items()
                )
            )

        if origin is set:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, set))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is set
            if not isinstance(value, set):
                return _return_func(False)
            # check all items in set are of the correct type
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))

        if origin is tuple:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, tuple))
            # check is tuple
            if not isinstance(value, tuple):
                return _return_func(False)
            # variable-length homogeneous tuple, e.g. `tuple[int, ...]`:
            # any length is fine, every element must match the single item type
            if len(args) == 2 and args[1] is Ellipsis:
                item_type = args[0]
                return _return_func(
                    all(validate_type(item, item_type) for item in value)
                )
            # fixed-length tuple: check correct number of items
            if len(value) != len(args):
                return _return_func(False)
            # check all items in tuple are of the correct type, position by position
            return _return_func(
                all(validate_type(item, arg) for item, arg in zip(value, args))
            )

        if origin is type:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, type))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # `type[X]` only accepts classes; guard before touching `__mro__`,
            # which non-class values (e.g. `5`) do not have
            if not isinstance(value, type):
                return _return_func(False)
            item_type = args[0]
            # `type[Any]` accepts any class
            if item_type is typing.Any:
                return _return_func(True)
            return _return_func(item_type in value.__mro__)

        # TODO: Callables, etc.

        raise TypeHintNotImplementedError(
            f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )

    else:
        raise TypeHintNotImplementedError(
            f"Unsupported type hint {expected_type = } for {value = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )
def get_fn_allowed_kwargs(fn: typing.Callable[..., Any]) -> typing.Set[str]:
    """Return the names of the parameters of `fn` that can be passed by keyword.

    `fn` is unwrapped (following `__wrapped__`) before its signature is read.
    Positional-only parameters, `*args`, and `**kwargs` are not included.

    # Raises
     - `ValueError`: if the signature of `fn` cannot be determined
    """
    try:
        fn = unwrap(fn)
        params = signature(fn).parameters
    except ValueError as e:
        fn_name: str = getattr(fn, "__name__", str(fn))
        err_msg = f"Cannot retrieve signature for {fn_name = } {fn = }: {str(e)}"
        raise ValueError(err_msg) from e

    allowed: typing.Set[str] = set()
    for param in params.values():
        keyword_passable = param.kind in (
            param.POSITIONAL_OR_KEYWORD,
            param.KEYWORD_ONLY,
        )
        if keyword_passable:
            allowed.add(param.name)
    return allowed