Coverage for muutils/validate_type.py: 74%
90 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""experimental utility for validating types in python, see `validate_type`"""
3from __future__ import annotations
5from inspect import signature, unwrap
6import types
7import typing
8import functools
# Python <3.10 compatibility: not every generic-alias class exists in every
# version of `typing`, so each one is looked up by name and only the classes
# that are actually present end up in `GenericAliasTypes`.
_GenericAliasTypeNames: typing.List[str] = [
    "GenericAlias",
    "_GenericAlias",
    "_UnionGenericAlias",
    "_BaseGenericAlias",
]

# one entry per name above; `None` where `typing` has no such attribute
_GenericAliasTypesList: list = list(
    map(lambda alias_name: getattr(typing, alias_name, None), _GenericAliasTypeNames)
)

# tuple form so it can be passed straight to `isinstance`
GenericAliasTypes: tuple = tuple(
    alias_cls for alias_cls in _GenericAliasTypesList if alias_cls is not None
)
class IncorrectTypeException(TypeError):
    """Raised by `validate_type` (when `do_except=True`) if a value does not match the expected type."""
class TypeHintNotImplementedError(NotImplementedError):
    """Raised by `validate_type` when given a type hint it does not know how to check."""
class InvalidGenericAliasError(TypeError):
    """Raised by `validate_type` when a generic alias has the wrong number of type arguments (e.g. `list[int, str]`)."""
37def _return_validation_except(
38 return_val: bool, value: typing.Any, expected_type: typing.Any
39) -> bool:
40 if return_val:
41 return True
42 else:
43 raise IncorrectTypeException(
44 f"Expected {expected_type = } for {value = }",
45 f"{type(value) = }",
46 f"{type(value).__mro__ = }",
47 f"{typing.get_origin(expected_type) = }",
48 f"{typing.get_args(expected_type) = }",
49 "\ndo --tb=long in pytest to see full trace",
50 )
51 return False
54def _return_validation_bool(return_val: bool) -> bool:
55 return return_val
def validate_type(
    value: typing.Any, expected_type: typing.Any, do_except: bool = False
) -> bool:
    """Validate that a `value` is of the `expected_type`

    # Parameters
    - `value`: the value to check the type of
    - `expected_type`: the type to check against. Not all types are supported
    - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`)
        (default: `False`)

    # Returns
    - `bool`: `True` if the value is of the expected type, `False` otherwise.

    # Raises
    - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True`
    - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented
    - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid

    # Supported type hints
    - `typing.Any`, `None` (treated as `type(None)`), and plain classes
    - `typing.Union[...]` / `X | Y`
    - `list[T]`, `dict[K, V]`, `set[T]`, `type[T]` (and their `typing` equivalents)
    - `tuple[T1, T2, ...]` (fixed length) and `tuple[T, ...]` (variable length)

    Nested element checks always run with `do_except=False`; when `do_except=True`,
    any raised `IncorrectTypeException` describes the outermost `value`/`expected_type`.

    use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard
    """
    if expected_type is typing.Any:
        return True

    # PEP 484: a bare `None` in a type hint stands for `type(None)`
    if expected_type is None:
        expected_type = type(None)

    # set up the return function depending on `do_except`
    _return_func: typing.Callable[[bool], bool] = (
        # functools.partial doesn't hint the function signature
        functools.partial(  # type: ignore[assignment]
            _return_validation_except, value=value, expected_type=expected_type
        )
        if do_except
        else _return_validation_bool
    )

    # base type without args
    if isinstance(expected_type, type):
        try:
            # if you use args on a type like `dict[str, int]`, this will fail
            is_instance: bool = isinstance(value, expected_type)
        except TypeError:
            # fall through to the generic-alias handling below
            pass
        else:
            return _return_func(is_instance)

    origin: typing.Any = typing.get_origin(expected_type)
    args: tuple = typing.get_args(expected_type)

    # useful for debugging
    # print(f"{value = }, {expected_type = }, {origin = }, {args = }")

    # `types.UnionType` (the `X | Y` syntax) only exists on python >= 3.10;
    # `typing.Union` works on older versions
    UnionType = getattr(types, "UnionType", None)
    if origin is typing.Union or (UnionType is not None and origin is UnionType):
        return _return_func(any(validate_type(value, arg) for arg in args))

    # everything past this point must be a generic alias
    if not isinstance(expected_type, GenericAliasTypes):
        raise TypeHintNotImplementedError(
            f"Unsupported type hint {expected_type = } for {value = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )

    item_type: typing.Any

    if origin is list:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, list))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is list
        if not isinstance(value, list):
            return _return_func(False)
        # check all items in list are of the correct type; routed through
        # `_return_func` so `do_except=True` raises here like the other branches
        item_type = args[0]
        return _return_func(all(validate_type(item, item_type) for item in value))

    if origin is dict:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, dict))
        # incorrect number of args
        if len(args) != 2:
            raise InvalidGenericAliasError(
                f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is dict
        if not isinstance(value, dict):
            return _return_func(False)
        # check all keys and values in dict are of the correct type
        key_type: typing.Any = args[0]
        value_type: typing.Any = args[1]
        return _return_func(
            all(
                validate_type(key, key_type) and validate_type(val, value_type)
                for key, val in value.items()
            )
        )

    if origin is set:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, set))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is set
        if not isinstance(value, set):
            return _return_func(False)
        # check all items in set are of the correct type
        item_type = args[0]
        return _return_func(all(validate_type(item, item_type) for item in value))

    if origin is tuple:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, tuple))
        # check is tuple
        if not isinstance(value, tuple):
            return _return_func(False)
        # variable-length homogeneous tuple, e.g. `tuple[int, ...]`
        if len(args) == 2 and args[1] is Ellipsis:
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))
        # fixed-length tuple: check correct number of items
        if len(value) != len(args):
            return _return_func(False)
        # check each item against its positional type
        return _return_func(
            all(validate_type(item, arg) for item, arg in zip(value, args))
        )

    if origin is type:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, type))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # a non-class value has no `__mro__` and cannot match `type[T]`
        if not isinstance(value, type):
            return _return_func(False)
        item_type = args[0]
        # `type[Any]` accepts any class
        if item_type is typing.Any:
            return _return_func(True)
        return _return_func(item_type in value.__mro__)

    # TODO: Callables, etc.

    raise TypeHintNotImplementedError(
        f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }",
        f"{origin = }, {args = }",
        f"\n{GenericAliasTypes = }",
    )
def get_fn_allowed_kwargs(fn: typing.Callable) -> typing.Set[str]:
    """Get the allowed kwargs for a function, raising an exception if the signature cannot be determined.

    `fn` is first passed through `inspect.unwrap`, so for a decorated function
    (one exposing `__wrapped__`) the inner function's parameters are reported.

    # Parameters
    - `fn`: the callable to inspect

    # Returns
    - `typing.Set[str]`: names of parameters that can be passed by keyword, i.e.
      positional-or-keyword and keyword-only parameters. Positional-only
      parameters, `*args`, and `**kwargs` are excluded.

    # Raises
    - `ValueError`: if unwrapping fails or the signature cannot be determined
    """
    try:
        fn = unwrap(fn)
        params = signature(fn).parameters
    except ValueError as e:
        # not every callable (e.g. callable instances, `functools.partial`)
        # has a `__name__`; building the message must not mask the real error
        fn_name: typing.Optional[str] = getattr(fn, "__name__", None)
        raise ValueError(
            f"Cannot retrieve signature for fn.__name__ = {fn_name!r} {fn = }: {str(e)}"
        ) from e

    return {
        param.name
        for param in params.values()
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }