Coverage for muutils/misc/sequence.py: 98%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1from __future__ import annotations 

2 

3from typing import ( 

4 Iterable, 

5 Any, 

6 Generator, 

7 Callable, 

8 Union, 

9) 

10 

11import typing 

12from typing import ( 

13 Literal, 

14 Mapping, 

15) 

16 

17 

# Allowed policies for keys absent from a mapping; consumed by `apply_mapping` and `apply_mapping_chain` below.
18WhenMissing = Literal["except", "skip", "include"] 

19 

20 

def empty_sequence_if_attr_false(
    itr: Iterable[Any],
    attr_owner: Any,
    attr_name: str,
) -> Iterable[Any]:
    """Pass `itr` through only when the attribute `attr_name` of `attr_owner` is truthy.

    A convenient way to optionally insert delimiters into a sequence, with the
    decision driven by an attribute of some object (e.g. a `TokenizerElement`).

    # Parameters:
     - `itr: Iterable[Any]`
        The iterable handed back when the attribute is truthy.
     - `attr_owner: Any`
        The object on which the attribute is looked up.
     - `attr_name: str`
        Name of the attribute to look up.

    # Returns:
     - `itr` itself (the same object, not a copy) if `attr_owner` has `attr_name` and its value is truthy
     - `()` if the attribute is falsy or does not exist
    """
    # A missing attribute falls back to `False`, i.e. it is treated like a falsy one.
    if getattr(attr_owner, attr_name, False):
        return itr
    return ()

43 

44 

def flatten(it: Iterable[Any], levels_to_flatten: int | None = None) -> Generator:
    """Lazily flatten a nested iterable.

    Anything exposing `__iter__` is descended into, except `str` and `bytes`,
    which are always emitted whole.

    # Parameters
     - `it`: Any arbitrarily nested iterable.
     - `levels_to_flatten`: How many levels of nesting to remove, counted from the
       outermost layer. `None` (the default) flattens completely; `0` yields the
       elements of `it` unchanged.

    # Returns
    Generator over the flattened sequence.
    """
    for element in it:
        # Emit the element as-is when it is a leaf (not iterable, or text/bytes)
        # or when the depth budget is used up.
        if (
            not hasattr(element, "__iter__")
            or isinstance(element, (str, bytes))
            or not (levels_to_flatten is None or levels_to_flatten > 0)
        ):
            yield element
            continue

        # One level consumed; `None` means "no limit" and is passed down unchanged.
        remaining = None if levels_to_flatten is None else levels_to_flatten - 1
        yield from flatten(element, remaining)

69 

70 

71# string-like operations on lists 

72# -------------------------------------------------------------------------------- 

73 

74 

def list_split(lst: Iterable[Any], val: Any) -> list[list]:
    """split a sequence into sublists by `val`. similar to "a_b_c".split("_")

    Every element equal to `val` (compared with `==`) acts as a separator and is
    dropped from the output. Any iterable is accepted, including one-shot
    iterators and generators; the input is traversed exactly once.

    ```python
    >>> list_split([1,2,3,0,4,5,0,6], 0)
    [[1, 2, 3], [4, 5], [6]]
    >>> list_split([0,1,2,3], 0)
    [[], [1, 2, 3]]
    >>> list_split([1,2,3], 0)
    [[1, 2, 3]]
    >>> list_split([], 0)
    [[]]
    ```

    # Parameters:
     - `lst`: the elements to split
     - `val`: the separator value

    # Returns:
     - `list[list]`: always holds one more sublist than there were separators,
       so empty input gives `[[]]`
    """
    # Start with one open sublist: empty input then naturally yields `[[]]`,
    # with no need for a separate length check.
    output: list[list] = [[]]

    for x in lst:
        if x == val:
            # separator: close the current sublist by opening a new one
            output.append([])
        else:
            output[-1].append(x)
    return output

104 

105 

def list_join(lst: Iterable[Any], factory: Callable) -> list:
    """add a *new* instance of `factory()` between each element of `lst`

    `factory` is called once per gap, in order, so every separator is a
    distinct object. Any iterable is accepted, including one-shot iterators and
    generators; the input is traversed exactly once and never copied.

    ```python
    >>> list_join([1,2,3], lambda : 0)
    [1, 0, 2, 0, 3]
    >>> list_join([1,2,3], lambda: [time.sleep(0.1), time.time()][1])
    [1, 1600000000.0, 2, 1600000000.1, 3]
    ```

    # Parameters:
     - `lst`: the elements to interleave with separators
     - `factory`: zero-argument callable producing a separator

    # Returns:
     - `list`: the elements of `lst` with a `factory()` result between each
       adjacent pair; `[]` for empty input
    """
    output: list = []

    for idx, x in enumerate(lst):
        # a separator goes *before* every element except the first
        if idx > 0:
            output.append(factory())
        output.append(x)

    return output

129 

130 

131# applying mappings 

132# -------------------------------------------------------------------------------- 

133 

# Type variables for the key type (`_AM_K`) and value type (`_AM_V`) of the mappings taken by `apply_mapping` and `apply_mapping_chain`.
134_AM_K = typing.TypeVar("_AM_K") 

135_AM_V = typing.TypeVar("_AM_V") 

136 

137 

def apply_mapping(
    mapping: Mapping[_AM_K, _AM_V],
    iter: Iterable[_AM_K],
    when_missing: WhenMissing = "skip",
) -> list[Union[_AM_K, _AM_V]]:
    """Translate each item of an iterable through a mapping, with a configurable policy for missing keys

    Gotcha: `when_missing` is only inspected when a key is actually absent, so an invalid value
    goes unnoticed as long as every item is found in the mapping.

    Note: to use a function instead of a dict, wrap it in `muutils.kappa.Kappa`

    # Parameters:
     - `mapping : Mapping[_AM_K, _AM_V]`
        needs `__contains__` and `__getitem__`; both receive an `_AM_K`, the latter returns an `_AM_V`
     - `iter : Iterable[_AM_K]`
        the items to translate
     - `when_missing : WhenMissing`
        policy for items not in the mapping -- the thing that sets this function apart from `map`.
        one of `"skip"` (drop the item), `"include"` (keep the item unconverted), or `"except"` (raise)
        (defaults to `"skip"`)

    # Returns:
    a new list, typed as one of:
     - `list[_AM_V]` if `when_missing` is `"skip"` or `"except"`
     - `list[Union[_AM_K, _AM_V]]` if `when_missing` is `"include"`

    # Raises:
     - `KeyError` : if an item is missing from the mapping and `when_missing` is `"except"`
     - `ValueError` : if an item is missing from the mapping and `when_missing` is not a valid option
    """
    translated: list[Union[_AM_K, _AM_V]] = []
    # NOTE: the loop variable must stay named `item` -- the f-strings below print its name
    for item in iter:
        if item in mapping:
            translated.append(mapping[item])
        elif when_missing == "skip":
            pass
        elif when_missing == "include":
            translated.append(item)
        elif when_missing == "except":
            raise KeyError(f"item {item} is missing from mapping {mapping}")
        else:
            raise ValueError(
                f"invalid value for {when_missing = }\n{item = }\n{mapping = }"
            )
    return translated

185 

186 

def apply_mapping_chain(
    mapping: Mapping[_AM_K, Iterable[_AM_V]],
    iter: Iterable[_AM_K],
    when_missing: WhenMissing = "skip",
) -> list[Union[_AM_K, _AM_V]]:
    """Translate each item of an iterable into an iterable of values and concatenate the results

    Each item found in the mapping contributes *all* the elements of `mapping[item]` to the output.

    Gotcha: `when_missing` is only inspected when a key is actually absent, so an invalid value
    goes unnoticed as long as every item is found in the mapping.

    Note: to use a function instead of a dict, wrap it in `muutils.kappa.Kappa`

    # Parameters:
     - `mapping : Mapping[_AM_K, Iterable[_AM_V]]`
        needs `__contains__` and `__getitem__`; both receive an `_AM_K`, the latter returns an `Iterable[_AM_V]`
     - `iter : Iterable[_AM_K]`
        the items to translate
     - `when_missing : WhenMissing`
        policy for items not in the mapping -- the thing that sets this function apart from `map`.
        one of `"skip"` (drop the item), `"include"` (keep the item itself, unconverted, as a single element),
        or `"except"` (raise)
        (defaults to `"skip"`)

    # Returns:
    a new list, typed as one of:
     - `list[_AM_V]` if `when_missing` is `"skip"` or `"except"`
     - `list[Union[_AM_K, _AM_V]]` if `when_missing` is `"include"`

    # Raises:
     - `KeyError` : if an item is missing from the mapping and `when_missing` is `"except"`
     - `ValueError` : if an item is missing from the mapping and `when_missing` is not a valid option

    """
    chained: list[Union[_AM_K, _AM_V]] = []
    # NOTE: the loop variable must stay named `item` -- the f-strings below print its name
    for item in iter:
        if item in mapping:
            # splice in every value the key maps to
            chained.extend(mapping[item])
        elif when_missing == "skip":
            pass
        elif when_missing == "include":
            # the missing key itself is kept as one element, not expanded
            chained.append(item)
        elif when_missing == "except":
            raise KeyError(f"item {item} is missing from mapping {mapping}")
        else:
            raise ValueError(
                f"invalid value for {when_missing = }\n{item = }\n{mapping = }"
            )
    return chained