Coverage for muutils / validate_type.py: 73%
95 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:25 -0700
1"""experimental utility for validating types in python, see `validate_type`"""
3from __future__ import annotations
5from inspect import signature, unwrap
6import types
7import typing
8import functools
9from typing import Any
# Names of the classes that parameterized hints such as `list[int]`,
# `typing.List[int]` or `typing.Union[int, str]` are instances of. Which of
# these exist differs between python versions, so they are looked up
# dynamically (this is also for python <3.10 compatibility).
_GenericAliasTypeNames: typing.List[str] = [
    "GenericAlias",
    "_GenericAlias",
    "_UnionGenericAlias",
    "_BaseGenericAlias",
]

# `GenericAlias` is defined in `types`; `typing` only re-exports it as an
# implementation detail, so both modules are searched. Names that do not
# exist on a given module/version yield `None` and are filtered out below.
_GenericAliasTypesList: list[Any] = [
    getattr(module, name, None)
    for name in _GenericAliasTypeNames
    for module in (typing, types)
]

# deduplicated (order-preserving) tuple of the alias classes that actually
# exist, suitable as the second argument to `isinstance`
GenericAliasTypes: tuple[Any, ...] = tuple(
    dict.fromkeys(t for t in _GenericAliasTypesList if t is not None)
)
class IncorrectTypeException(TypeError):
    """Raised by `validate_type` when `do_except=True` and the value does not match the expected type."""
class TypeHintNotImplementedError(NotImplementedError):
    """Raised by `validate_type` when it has no handling for the given type hint."""
class InvalidGenericAliasError(TypeError):
    """Raised by `validate_type` when a generic alias has the wrong number of type arguments (e.g. `list[int, str]`)."""
40def _return_validation_except(
41 return_val: bool, value: typing.Any, expected_type: typing.Any
42) -> bool:
43 if return_val:
44 return True
45 else:
46 raise IncorrectTypeException(
47 f"Expected {expected_type = } for {value = }",
48 f"{type(value) = }",
49 f"{type(value).__mro__ = }",
50 f"{typing.get_origin(expected_type) = }",
51 f"{typing.get_args(expected_type) = }",
52 "\ndo --tb=long in pytest to see full trace",
53 )
54 return False
57def _return_validation_bool(return_val: bool) -> bool:
58 return return_val
def validate_type(
    value: typing.Any, expected_type: typing.Any, do_except: bool = False
) -> bool:
    """Validate that a `value` is of the `expected_type`

    # Parameters
     - `value`: the value to check the type of
     - `expected_type`: the type to check against. Not all types are supported
     - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`)
       (default: `False`)

    # Returns
     - `bool`: `True` if the value is of the expected type, `False` otherwise.

    # Raises
     - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True`
     - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented
     - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid

    # Supported type hints
     - `typing.Any`, `None`, and plain classes (checked via `isinstance`)
     - unions: `typing.Union[...]`, `typing.Optional[...]`, and `X | Y`
     - `list[T]`, `set[T]`, `dict[K, V]`, `type[T]` (bare or parameterized)
     - `tuple[A, B, ...]` (fixed length) and `tuple[T, ...]` (variable length)

    Nested checks (union members, container elements) always run with
    `do_except=False`. With `do_except=True`, the raised
    `IncorrectTypeException` therefore describes the outermost
    `value`/`expected_type`.

    use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard
    """
    if expected_type is typing.Any:
        return True

    # set up the return function depending on `do_except`: either pass the
    # result through, or raise `IncorrectTypeException` when it is falsy
    _return_func: typing.Callable[[bool], bool] = (
        # functools.partial doesn't hint the function signature
        functools.partial(  # type: ignore[assignment]
            _return_validation_except, value=value, expected_type=expected_type
        )
        if do_except
        else _return_validation_bool
    )

    # handle None type (used in type hints like tuple[int, None])
    if expected_type is None:
        return _return_func(value is None)

    # base type without args
    if isinstance(expected_type, type):
        try:
            # if you use args on a type like `dict[str, int]`, this will fail
            # with a `TypeError`; we then fall through to the generic alias
            # handling below
            return _return_func(isinstance(value, expected_type))
        except IncorrectTypeException:
            # `IncorrectTypeException` is itself a `TypeError`: a genuine
            # mismatch reported by `_return_func` must propagate
            raise
        except TypeError:
            pass

    origin: typing.Any = typing.get_origin(expected_type)
    args: tuple[Any, ...] = typing.get_args(expected_type)

    # useful for debugging
    # print(f"{value = }, {expected_type = }, {origin = }, {args = }")

    # `typing.Union` covers `Union[...]` / `Optional[...]` (this works in
    # python <3.10); `types.UnionType` (`X | Y`) may not exist, hence getattr
    UnionType = getattr(types, "UnionType", None)
    if (origin is typing.Union) or (UnionType is not None and origin is UnionType):
        return _return_func(any(validate_type(value, arg) for arg in args))

    # everything past this point must be a generic alias
    if not isinstance(expected_type, GenericAliasTypes):
        raise TypeHintNotImplementedError(
            f"Unsupported type hint {expected_type = } for {value = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )

    item_type: type

    if origin is list:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, list))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is list
        if not isinstance(value, list):
            return _return_func(False)
        # check all items in list are of the correct type; routed through
        # `_return_func` so that `do_except=True` raises on a bad element
        item_type = args[0]
        return _return_func(all(validate_type(item, item_type) for item in value))

    if origin is dict:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, dict))
        # incorrect number of args
        if len(args) != 2:
            raise InvalidGenericAliasError(
                f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is dict
        if not isinstance(value, dict):
            return _return_func(False)
        # check all keys and values in dict are of the correct type
        key_type: type = args[0]
        value_type: type = args[1]
        return _return_func(
            all(
                validate_type(key, key_type) and validate_type(val, value_type)
                for key, val in value.items()
            )
        )

    if origin is set:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, set))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is set
        if not isinstance(value, set):
            return _return_func(False)
        # check all items in set are of the correct type
        item_type = args[0]
        return _return_func(all(validate_type(item, item_type) for item in value))

    if origin is tuple:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, tuple))
        # check is tuple
        if not isinstance(value, tuple):
            return _return_func(False)
        # variable-length `tuple[T, ...]`: `get_args` gives `(T, Ellipsis)`;
        # any length is allowed, but every item must be a `T`
        if len(args) == 2 and args[1] is Ellipsis:
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))
        # fixed-length tuple: check correct number of items
        if len(value) != len(args):
            return _return_func(False)
        # check each item against its positional type
        return _return_func(
            all(validate_type(item, arg) for item, arg in zip(value, args))
        )

    if origin is type:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, type))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # non-class values (e.g. `5`) have no `__mro__` and can never match
        if not isinstance(value, type):
            return _return_func(False)
        item_type = args[0]
        # `type[Any]` accepts any class
        if item_type is typing.Any:
            return _return_func(True)
        # `value` must be `item_type` or have it in its MRO
        return _return_func(item_type in value.__mro__)

    # TODO: Callables, etc.

    raise TypeHintNotImplementedError(
        f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }",
        f"{origin = }, {args = }",
        f"\n{GenericAliasTypes = }",
    )
def get_fn_allowed_kwargs(fn: typing.Callable[..., Any]) -> typing.Set[str]:
    """Collect the parameter names of `fn` that can be passed by keyword.

    `fn` is first unwrapped with `inspect.unwrap` (following `__wrapped__`
    chains) before its signature is inspected.

    # Parameters
     - `fn`: the callable to inspect

    # Returns
     - `typing.Set[str]`: names of all positional-or-keyword and keyword-only
       parameters. Positional-only parameters, `*args`, and `**kwargs` are
       excluded.

    # Raises
     - `ValueError`: if unwrapping or retrieving the signature raises a
       `ValueError`. The original error is chained as the cause.
    """
    try:
        # rebinding `fn` matters: on failure the message reports the
        # unwrapped callable if unwrapping itself succeeded
        fn = unwrap(fn)
        params = signature(fn).parameters
    except ValueError as err:
        fn_name: str = getattr(fn, "__name__", str(fn))
        raise ValueError(
            f"Cannot retrieve signature for {fn_name = } {fn = }: {str(err)}"
        ) from err

    allowed: typing.Set[str] = set()
    for name, param in params.items():
        is_keyword_capable: bool = (
            param.kind == param.POSITIONAL_OR_KEYWORD
            or param.kind == param.KEYWORD_ONLY
        )
        if is_keyword_capable:
            allowed.add(name)
    return allowed