Coverage for muutils / validate_type.py: 73%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:25 -0700

1"""experimental utility for validating types in python, see `validate_type`""" 

2 

3from __future__ import annotations 

4 

5from inspect import signature, unwrap 

6import types 

7import typing 

8import functools 

9from typing import Any 

10 

# typing-internal generic-alias class names to recognise; they are looked up
# by name because not every one of them exists on every python version
# (this is also for python <3.10 compatibility)
_GenericAliasTypeNames: typing.List[str] = [
    "GenericAlias",
    "_GenericAlias",
    "_UnionGenericAlias",
    "_BaseGenericAlias",
]

# one entry per name above, `None` wherever `typing` lacks that attribute
_GenericAliasTypesList: list[Any] = list(
    map(lambda name: getattr(typing, name, None), _GenericAliasTypeNames)
)

# only the classes that actually exist, as a tuple so it can be handed
# directly to `isinstance`
GenericAliasTypes: tuple[Any, ...] = tuple(
    filter(lambda alias_cls: alias_cls is not None, _GenericAliasTypesList)
)

26 

27 

class IncorrectTypeException(TypeError):
    """Raised by `validate_type` when `do_except=True` and the value does not match the expected type."""

30 

31 

class TypeHintNotImplementedError(NotImplementedError):
    """Raised by `validate_type` for a type hint it does not know how to check."""

34 

35 

class InvalidGenericAliasError(TypeError):
    """Raised by `validate_type` when a generic alias has the wrong number of type arguments (e.g. `dict[int]`)."""

38 

39 

def _return_validation_except(
    return_val: bool, value: typing.Any, expected_type: typing.Any
) -> bool:
    """Return `True` for a passing check, raise `IncorrectTypeException` otherwise.

    This is the raising counterpart of `_return_validation_bool`; `validate_type`
    binds `value` and `expected_type` with `functools.partial` when `do_except=True`.

    # Parameters
     - `return_val`: result of the type check
     - `value`: the value that was checked (only used in the error message)
     - `expected_type`: the expected type hint (only used in the error message)

    # Returns
     - `bool`: always `True`; a falsy `return_val` raises instead

    # Raises
     - `IncorrectTypeException`: if `return_val` is falsy
    """
    if return_val:
        return True
    # the raise is unconditional here, so nothing after it can run
    raise IncorrectTypeException(
        f"Expected {expected_type = } for {value = }",
        f"{type(value) = }",
        f"{type(value).__mro__ = }",
        f"{typing.get_origin(expected_type) = }",
        f"{typing.get_args(expected_type) = }",
        "\ndo --tb=long in pytest to see full trace",
    )

55 

56 

def _return_validation_bool(return_val: bool) -> bool:
    """Hand the check result back unchanged (the non-raising mode of `validate_type`)."""
    result: bool = return_val
    return result

59 

60 

def validate_type(
    value: typing.Any, expected_type: typing.Any, do_except: bool = False
) -> bool:
    """Validate that a `value` is of the `expected_type`

    # Parameters
     - `value`: the value to check the type of
     - `expected_type`: the type to check against. Not all types are supported
     - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`)
       (default: `False`)

    # Returns
     - `bool`: `True` if the value is of the expected type, `False` otherwise.

    # Raises
     - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True`
     - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented
     - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid

    Supported hints: `typing.Any`, `None`, plain classes, `Union[...]` / `X | Y`,
    and the generic aliases `list[T]`, `dict[K, V]`, `set[T]`,
    `tuple[A, B, ...]` (fixed length), `tuple[T, ...]` (homogeneous, any length)
    and `type[T]`.

    Nested checks (union members, container items, dict keys/values) always run
    with `do_except=False`; with `do_except=True` only the overall result of the
    outermost check raises.

    use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard
    """
    if expected_type is typing.Any:
        return True

    # set up the return function depending on `do_except`: either raise on a
    # `False` result, or pass the result straight through
    _return_func: typing.Callable[[bool], bool] = (
        # functools.partial doesn't hint the function signature
        functools.partial(  # type: ignore[assignment]
            _return_validation_except, value=value, expected_type=expected_type
        )
        if do_except
        else _return_validation_bool
    )

    # handle None type (used in type hints like tuple[int, None])
    if expected_type is None:
        return _return_func(value is None)

    # base type without args
    if isinstance(expected_type, type):
        try:
            # if you use args on a type like `dict[str, int]`, this will fail
            return _return_func(isinstance(value, expected_type))
        except TypeError as e:
            # a failed check under `do_except` must propagate; any other
            # TypeError means `expected_type` is parameterized, so fall
            # through to the generic-alias handling below
            if isinstance(e, IncorrectTypeException):
                raise

    origin: typing.Any = typing.get_origin(expected_type)
    args: tuple[Any, ...] = typing.get_args(expected_type)

    # useful for debugging
    # print(f"{value = }, {expected_type = }, {origin = }, {args = }")

    # `typing.Union[...]` works on all versions; PEP 604 `X | Y` unions have
    # origin `types.UnionType`, which is absent before python 3.10
    UnionType = getattr(types, "UnionType", None)
    if origin is typing.Union or (UnionType is not None and origin is UnionType):
        return _return_func(any(validate_type(value, arg) for arg in args))

    if not isinstance(expected_type, GenericAliasTypes):
        raise TypeHintNotImplementedError(
            f"Unsupported type hint {expected_type = } for {value = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )

    # generic alias, more complicated
    item_type: type

    if origin is list:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, list))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is list
        if not isinstance(value, list):
            return _return_func(False)
        # check all items in list are of the correct type; go through
        # `_return_func` so `do_except=True` raises on a bad item
        item_type = args[0]
        return _return_func(all(validate_type(item, item_type) for item in value))

    if origin is dict:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, dict))
        # incorrect number of args
        if len(args) != 2:
            raise InvalidGenericAliasError(
                f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is dict
        if not isinstance(value, dict):
            return _return_func(False)
        # check all items in dict are of the correct type
        key_type: type = args[0]
        value_type: type = args[1]
        return _return_func(
            all(
                validate_type(key, key_type) and validate_type(val, value_type)
                for key, val in value.items()
            )
        )

    if origin is set:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, set))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is set
        if not isinstance(value, set):
            return _return_func(False)
        # check all items in set are of the correct type
        item_type = args[0]
        return _return_func(all(validate_type(item, item_type) for item in value))

    if origin is tuple:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, tuple))
        # check is tuple
        if not isinstance(value, tuple):
            return _return_func(False)
        # `tuple[T, ...]`: any length, every item of type T
        # (`typing.get_args` reports the `...` as `Ellipsis`)
        if len(args) == 2 and args[1] is Ellipsis:
            item_type = args[0]
            return _return_func(
                all(validate_type(item, item_type) for item in value)
            )
        # fixed-length tuple: check correct number of items in tuple
        if len(value) != len(args):
            return _return_func(False)
        # check all items in tuple are of the correct type
        return _return_func(
            all(validate_type(item, arg) for item, arg in zip(value, args))
        )

    if origin is type:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, type))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # `value` must itself be a class; non-class objects have no `__mro__`
        if not isinstance(value, type):
            return _return_func(False)
        item_type = args[0]
        # `type[Any]` accepts any class (`Any` never appears in an MRO)
        if item_type is typing.Any:
            return _return_func(True)
        # check is type
        return _return_func(item_type in value.__mro__)

    # TODO: Callables, etc.

    raise TypeHintNotImplementedError(
        f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }",
        f"{origin = }, {args = }",
        f"\n{GenericAliasTypes = }",
    )

228 

229 

def get_fn_allowed_kwargs(fn: typing.Callable[..., Any]) -> typing.Set[str]:
    """Get the names of the parameters of `fn` that may be passed by keyword.

    `fn` is first unwrapped with `inspect.unwrap` (following `__wrapped__`),
    so for a decorated function the wrapped function's parameters are reported.
    Only positional-or-keyword and keyword-only parameters are included;
    positional-only parameters, `*args` and `**kwargs` are excluded.

    # Raises
     - `ValueError`: if unwrapping fails or the signature cannot be determined
    """
    try:
        fn = unwrap(fn)
        params = signature(fn).parameters
    except ValueError as exc:
        # `fn` here is the unwrapped callable if unwrapping itself succeeded
        fn_name: str = getattr(fn, "__name__", str(fn))
        raise ValueError(
            f"Cannot retrieve signature for {fn_name = } {fn = }: {str(exc)}"
        ) from exc

    allowed: typing.Set[str] = set()
    for param in params.values():
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY):
            allowed.add(param.name)
    return allowed