Coverage for muutils / validate_type.py: 73%

97 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-18 21:32 -0600

1"""experimental utility for validating types in python, see `validate_type`""" 

2 

3from __future__ import annotations 

4 

5from inspect import signature, unwrap 

6import types 

7import typing 

8import functools 

9from typing import Any 

10 

# runtime classes of parameterized generics (e.g. `typing.List[int]`, `list[int]`),
# looked up by name so that missing ones are simply skipped -- this is also for
# python <3.10 compatibility
_GenericAliasTypeNames: typing.List[str] = [
    "GenericAlias",
    "_GenericAlias",
    "_UnionGenericAlias",
    "_BaseGenericAlias",
]

_GenericAliasTypesList: list[Any] = [
    getattr(typing, name, None) for name in _GenericAliasTypeNames
]
# `GenericAlias` (the class of `list[int]`) is documented in the `types` module;
# `typing.GenericAlias` is not part of typing's public API, so also look it up
# at its documented home instead of relying on the re-export
_GenericAliasTypesList.append(getattr(types, "GenericAlias", None))

# drop unavailable entries and de-duplicate (preserving order), since
# `typing.GenericAlias` and `types.GenericAlias` may be the same object
GenericAliasTypes: tuple[Any, ...] = tuple(
    dict.fromkeys(t for t in _GenericAliasTypesList if t is not None)
)

26 

27 

class IncorrectTypeException(TypeError):
    """Raised by `validate_type` when `do_except=True` and the value does not match the expected type."""

30 

31 

class TypeHintNotImplementedError(NotImplementedError):
    """Raised by `validate_type` when it meets a type hint it does not know how to check."""

34 

35 

class InvalidGenericAliasError(TypeError):
    """Raised by `validate_type` when a generic alias has the wrong number of type arguments (e.g. a `dict` alias without exactly two)."""

38 

39 

40def _return_validation_except( 

41 return_val: bool, value: typing.Any, expected_type: typing.Any 

42) -> bool: 

43 if return_val: 

44 return True 

45 else: 

46 raise IncorrectTypeException( 

47 f"Expected {expected_type = } for {value = }", 

48 f"{type(value) = }", 

49 f"{type(value).__mro__ = }", 

50 f"{typing.get_origin(expected_type) = }", 

51 f"{typing.get_args(expected_type) = }", 

52 "\ndo --tb=long in pytest to see full trace", 

53 ) 

54 return False 

55 

56 

57def _return_validation_bool(return_val: bool) -> bool: 

58 return return_val 

59 

60 

def validate_type(
    value: typing.Any, expected_type: typing.Any, do_except: bool = False
) -> bool:
    """Validate that a `value` is of the `expected_type`

    # Parameters
    - `value`: the value to check the type of
    - `expected_type`: the type to check against. Not all types are supported
    - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`)
        (default: `False`). Nested elements are checked without raising; only the
        outermost mismatch raises, with the full `expected_type` in the message.

    # Returns
    - `bool`: `True` if the value is of the expected type, `False` otherwise.

    # Supported type hints
    - `typing.Any` and `None`
    - plain classes (checked with `isinstance`)
    - unions: `typing.Union[...]`, `typing.Optional[...]`, and `X | Y` where available
    - `typing.Literal[...]` (a `bool` only matches `bool` literals, so `True` does not
      satisfy `Literal[1]`)
    - `list[X]`, `set[X]`, `dict[K, V]`, and their unparameterized forms
    - `tuple[X, Y, ...]` (fixed length), `tuple[X, ...]` (homogeneous, any length),
      and unparameterized `tuple`
    - `type[X]`: the value must be a class with `X` in its MRO (`type[Any]` accepts any class)

    # Raises
    - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True`
    - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented
    - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid

    use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard
    """
    if expected_type is typing.Any:
        return True

    # set up the return function depending on `do_except`
    _return_func: typing.Callable[[bool], bool] = (
        # functools.partial doesn't hint the function signature
        functools.partial(  # type: ignore[assignment]
            _return_validation_except, value=value, expected_type=expected_type
        )
        if do_except
        else _return_validation_bool
    )

    # handle None type (used in type hints like tuple[int, None])
    if expected_type is None:
        return _return_func(value is None)

    # base type without args
    if isinstance(expected_type, type):
        try:
            # if you use args on a type like `dict[str, int]`, `isinstance` raises
            # TypeError; in that case fall through to the generic alias handling below
            return _return_func(isinstance(value, expected_type))
        except TypeError as e:
            # `IncorrectTypeException` is itself a TypeError, but it signals a real
            # mismatch from `_return_func`, so it must propagate
            if isinstance(e, IncorrectTypeException):
                raise e

    origin: typing.Any = typing.get_origin(expected_type)
    args: tuple[Any, ...] = typing.get_args(expected_type)

    # useful for debugging
    # print(f"{value = }, {expected_type = }, {origin = }, {args = }")
    UnionType = getattr(types, "UnionType", None)

    if (origin is typing.Union) or ( # this works in python <3.10
        False
        if UnionType is None  # return False if UnionType is not available
        else origin is UnionType  # return True if UnionType is available
    ):
        return _return_func(any(validate_type(value, arg) for arg in args))

    # generic alias, more complicated
    item_type: type
    if isinstance(expected_type, GenericAliasTypes):
        if origin is typing.Literal:
            # `bool` subclasses `int` and `True == 1`, so a plain `value in args`
            # would let `True` satisfy `Literal[1]` (and `0` satisfy
            # `Literal[False]`); additionally require that bool-ness agrees
            return _return_func(
                any(
                    (value is arg or value == arg)
                    and isinstance(value, bool) == isinstance(arg, bool)
                    for arg in args
                )
            )

        if origin is list:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, list))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is list
            if not isinstance(value, list):
                return _return_func(False)
            # check all items in list are of the correct type
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))

        if origin is dict:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, dict))
            # incorrect number of args
            if len(args) != 2:
                raise InvalidGenericAliasError(
                    f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is dict
            if not isinstance(value, dict):
                return _return_func(False)
            # check all items in dict are of the correct type
            key_type: type = args[0]
            value_type: type = args[1]
            return _return_func(
                all(
                    validate_type(key, key_type) and validate_type(val, value_type)
                    for key, val in value.items()
                )
            )

        if origin is set:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, set))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is set
            if not isinstance(value, set):
                return _return_func(False)
            # check all items in set are of the correct type
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))

        if origin is tuple:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, tuple))
            # check is tuple
            if not isinstance(value, tuple):
                return _return_func(False)
            # `tuple[X, ...]`: homogeneous tuple of any length. `get_args` reports the
            # ellipsis as a literal `Ellipsis` second argument, which must not be
            # treated as a per-position type
            if len(args) == 2 and args[1] is Ellipsis:
                item_type = args[0]
                return _return_func(
                    all(validate_type(item, item_type) for item in value)
                )
            # check correct number of items in tuple
            if len(value) != len(args):
                return _return_func(False)
            # check all items in tuple are of the correct type
            return _return_func(
                all(validate_type(item, arg) for item, arg in zip(value, args))
            )

        if origin is type:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, type))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # a non-class value can never satisfy `type[X]` (and has no `__mro__`
            # to inspect, which previously raised AttributeError)
            if not isinstance(value, type):
                return _return_func(False)
            item_type = args[0]
            # `type[Any]` accepts any class
            if item_type is typing.Any:
                return _return_func(True)
            # check `item_type` is among the class's bases
            return _return_func(item_type in value.__mro__)

        # TODO: Callables, etc.

        raise TypeHintNotImplementedError(
            f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )

    else:
        raise TypeHintNotImplementedError(
            f"Unsupported type hint {expected_type = } for {value = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )

231 

232 

def get_fn_allowed_kwargs(fn: typing.Callable[..., Any]) -> typing.Set[str]:
    """Return the names of `fn`'s parameters that can be passed by keyword.

    `fn` is first unwrapped (following its `__wrapped__` chain, as set by
    `functools.wraps`), then its signature is inspected. Only positional-or-keyword
    and keyword-only parameters are kept; positional-only parameters, `*args`
    and `**kwargs` are excluded.

    # Raises
    - `ValueError`: if the signature cannot be retrieved; the message names the function
    """
    try:
        fn = unwrap(fn)
        params = signature(fn).parameters
    except ValueError as err:
        # `fn_name` / `fn` names matter: the f-string below prints them via `=`
        fn_name: str = getattr(fn, "__name__", str(fn))
        raise ValueError(
            f"Cannot retrieve signature for {fn_name = } {fn = }: {str(err)}"
        ) from err

    allowed: typing.Set[str] = set()
    for p in params.values():
        keyword_capable = p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
        if keyword_capable:
            allowed.add(p.name)
    return allowed