Coverage for muutils/validate_type.py: 74%

90 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""experimental utility for validating types in python, see `validate_type`""" 

2 

3from __future__ import annotations 

4 

5from inspect import signature, unwrap 

6import types 

7import typing 

8import functools 

9 

# Classes used to represent parametrized generics such as `typing.List[int]` or
# `list[int]`. Which of these exist depends on the python version, so each one is
# looked up with `getattr(..., None)` and missing ones are dropped.
# this is also for python <3.10 compatibility
_GenericAliasTypeNames: typing.List[str] = [
    "GenericAlias",
    "_GenericAlias",
    "_UnionGenericAlias",
    "_BaseGenericAlias",
]

# `list[int]` is an instance of `types.GenericAlias`; `typing.GenericAlias` is not
# part of typing's documented API, so also look the class up on `types` directly.
_GenericAliasTypesList: list = [
    getattr(typing, name, None) for name in _GenericAliasTypeNames
] + [getattr(types, "GenericAlias", None)]

# drop unavailable entries and duplicates (e.g. `typing.GenericAlias` may be the
# very same class as `types.GenericAlias`), preserving order
GenericAliasTypes: tuple = tuple(
    dict.fromkeys(t for t in _GenericAliasTypesList if t is not None)
)

23 

24 

class IncorrectTypeException(TypeError):
    """Raised by `validate_type(..., do_except=True)` when a value does not match
    the expected type hint."""

27 

28 

class TypeHintNotImplementedError(NotImplementedError):
    """Raised by `validate_type` when it is given a type hint (or generic alias
    origin) that it does not know how to check."""

31 

32 

class InvalidGenericAliasError(TypeError):
    """Raised by `validate_type` when a supported generic alias has the wrong
    number of type arguments (e.g. `list` with more than one)."""

35 

36 

37def _return_validation_except( 

38 return_val: bool, value: typing.Any, expected_type: typing.Any 

39) -> bool: 

40 if return_val: 

41 return True 

42 else: 

43 raise IncorrectTypeException( 

44 f"Expected {expected_type = } for {value = }", 

45 f"{type(value) = }", 

46 f"{type(value).__mro__ = }", 

47 f"{typing.get_origin(expected_type) = }", 

48 f"{typing.get_args(expected_type) = }", 

49 "\ndo --tb=long in pytest to see full trace", 

50 ) 

51 return False 

52 

53 

54def _return_validation_bool(return_val: bool) -> bool: 

55 return return_val 

56 

57 

def validate_type(
    value: typing.Any, expected_type: typing.Any, do_except: bool = False
) -> bool:
    """Validate that a `value` is of the `expected_type`

    Supported hints: `typing.Any`, plain classes, unions (`typing.Union`,
    `typing.Optional`, and `X | Y` where available), and the generic aliases
    `list[X]`, `dict[K, V]`, `set[X]`, `tuple[X, Y, ...]`, `tuple[X, ...]`
    (variable length) and `type[X]` (both the builtin and `typing.*` spellings).
    Element types are checked recursively with `do_except=False`, so only the
    outermost mismatch raises when `do_except=True`.

    # Parameters
     - `value`: the value to check the type of
     - `expected_type`: the type to check against. Not all types are supported
     - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`)
        (default: `False`)

    # Returns
     - `bool`: `True` if the value is of the expected type, `False` otherwise.

    # Raises
     - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True`
     - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint (or a nested
        element type hint) is not implemented
     - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid

    use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard
    """
    if expected_type is typing.Any:
        return True

    # set up the return function depending on `do_except`
    _return_func: typing.Callable[[bool], bool] = (
        # functools.partial doesn't hint the function signature
        functools.partial(  # type: ignore[assignment]
            _return_validation_except, value=value, expected_type=expected_type
        )
        if do_except
        else _return_validation_bool
    )

    # base type without args
    if isinstance(expected_type, type):
        try:
            # if you use args on a type like `dict[str, int]`, this will fail;
            # on some python versions such aliases still pass the `isinstance(..., type)`
            # check above, so a `TypeError` here means "fall through to generic handling"
            return _return_func(isinstance(value, expected_type))
        except IncorrectTypeException:
            # a genuine type mismatch reported by `_return_func`, not an isinstance failure
            raise
        except TypeError:
            pass

    origin: typing.Any = typing.get_origin(expected_type)
    args: tuple = typing.get_args(expected_type)

    # useful for debugging
    # print(f"{value = }, {expected_type = }, {origin = }, {args = }")

    # `types.UnionType` (the type of `int | str`) only exists on python >= 3.10
    UnionType = getattr(types, "UnionType", None)

    if origin is typing.Union or (UnionType is not None and origin is UnionType):
        return _return_func(any(validate_type(value, arg) for arg in args))

    # generic alias, more complicated
    if not isinstance(expected_type, GenericAliasTypes):
        raise TypeHintNotImplementedError(
            f"Unsupported type hint {expected_type = } for {value = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )

    item_type: typing.Any

    if origin is list:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, list))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is list
        if not isinstance(value, list):
            return _return_func(False)
        # check all items are of the correct type; routed through `_return_func`
        # so that `do_except=True` raises on a wrongly-typed item
        item_type = args[0]
        return _return_func(all(validate_type(item, item_type) for item in value))

    if origin is dict:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, dict))
        # incorrect number of args
        if len(args) != 2:
            raise InvalidGenericAliasError(
                f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is dict
        if not isinstance(value, dict):
            return _return_func(False)
        # check all keys and values are of the correct type
        key_type: typing.Any = args[0]
        value_type: typing.Any = args[1]
        return _return_func(
            all(
                validate_type(key, key_type) and validate_type(val, value_type)
                for key, val in value.items()
            )
        )

    if origin is set:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, set))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # check is set
        if not isinstance(value, set):
            return _return_func(False)
        # check all items in set are of the correct type
        item_type = args[0]
        return _return_func(all(validate_type(item, item_type) for item in value))

    if origin is tuple:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, tuple))
        # check is tuple
        if not isinstance(value, tuple):
            return _return_func(False)
        # variable-length homogeneous tuple: `tuple[X, ...]` has args `(X, Ellipsis)`
        if len(args) == 2 and args[1] is Ellipsis:
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))
        # fixed-length tuple: check correct number of items
        if len(value) != len(args):
            return _return_func(False)
        # check each item against its positional type
        return _return_func(
            all(validate_type(item, arg) for item, arg in zip(value, args))
        )

    if origin is type:
        # no args
        if len(args) == 0:
            return _return_func(isinstance(value, type))
        # incorrect number of args
        if len(args) != 1:
            raise InvalidGenericAliasError(
                f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }",
                f"{GenericAliasTypes = }",
            )
        # `value` must itself be a class before its MRO can be inspected
        if not isinstance(value, type):
            return _return_func(False)
        item_type = args[0]
        # `type[Any]` accepts any class; `Any` never appears in an MRO
        if item_type is typing.Any:
            return _return_func(True)
        return _return_func(item_type in value.__mro__)

    # TODO: Callables, etc.

    raise TypeHintNotImplementedError(
        f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }",
        f"{origin = }, {args = }",
        f"\n{GenericAliasTypes = }",
    )

221 

222 

def get_fn_allowed_kwargs(fn: typing.Callable) -> typing.Set[str]:
    """Get the allowed kwargs for a function, raising an exception if the signature cannot be determined.

    The callable is first unwrapped (following `__wrapped__`), so for a decorated
    function the innermost function's signature is used.

    # Parameters
     - `fn`: the callable to inspect

    # Returns
     - `typing.Set[str]`: names of parameters that may be passed by keyword, i.e.
        positional-or-keyword and keyword-only parameters. `*args`, `**kwargs` and
        positional-only parameters are excluded.

    # Raises
     - `ValueError`: if unwrapping fails or the signature cannot be retrieved
        (other exceptions from `inspect.signature`, e.g. `TypeError`, propagate unchanged)
    """
    try:
        fn = unwrap(fn)
        params = signature(fn).parameters
    except ValueError as e:
        # not every callable has `__name__` (e.g. `functools.partial` objects or
        # callable instances); accessing it directly would raise an AttributeError
        # that masks the actual problem
        fn_name = getattr(fn, "__name__", None)
        raise ValueError(
            f"Cannot retrieve signature for fn.__name__ = {fn_name!r} {fn = }: {str(e)}"
        ) from e

    return {
        param.name
        for param in params.values()
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }