Coverage for muutils / json_serialize / array.py: 60%

114 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1"""this utilities module handles serialization and loading of numpy and torch arrays as json 

2 

3- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability. 

4- `array_b64_meta` is the most efficient, but is not human readable. 

5- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ) 

6 

7""" 

8 

9from __future__ import annotations 

10 

11import base64 

12import typing 

13import warnings 

14from typing import ( 

15 TYPE_CHECKING, 

16 Any, 

17 Iterable, 

18 Literal, 

19 Optional, 

20 Sequence, 

21 TypedDict, 

22 Union, 

23 overload, 

24) 

25 

26try: 

27 import numpy as np 

28except ImportError as e: 

29 warnings.warn( 

30 f"numpy is not installed, array serialization will not work: \n{e}", 

31 ImportWarning, 

32 ) 

33 

34if TYPE_CHECKING: 

35 import numpy as np 

36 import torch 

37 from muutils.json_serialize.json_serialize import JsonSerializer 

38 

39from muutils.json_serialize.types import _FORMAT_KEY # pyright: ignore[reportPrivateUsage] 

40 

41# TYPING: pyright complains way too much here 

42# pyright: reportCallIssue=false,reportArgumentType=false,reportUnknownVariableType=false,reportUnknownMemberType=false 

43 

# Recursive type for nested numeric lists (output of arr.tolist())
# (note: for a zero-dimensional array `tolist()` yields a bare scalar, which this alias does not describe)
NumericList = typing.Union[
    typing.List[typing.Union[int, float, bool]],
    typing.List["NumericList"],
]

# every mode understood by `serialize_array()` / `infer_array_mode()` / `load_array()`:
# - "list":            plain nested lists, no metadata
# - "array_list_meta": nested lists under "data", plus metadata
# - "array_hex_meta":  hex string of the raw bytes under "data", plus metadata
# - "array_b64_meta":  base64 string of the raw bytes under "data", plus metadata
# - "external":        never written by `serialize_array()` here; `load_array()` returns its
#                      "data" value unchanged (presumably produced/handled by ZANJ)
# - "zero_dim":        the single scalar value under "data", plus metadata
ArrayMode = Literal[
    "list",
    "array_list_meta",
    "array_hex_meta",
    "array_b64_meta",
    "external",
    "zero_dim",
]

# Modes that produce SerializedArrayWithMeta (dict with metadata)
# i.e. every `ArrayMode` except "list"
ArrayModeWithMeta = Literal[
    "array_list_meta",
    "array_hex_meta",
    "array_b64_meta",
    "zero_dim",
    "external",
]

67 

68 

def array_n_elements(arr: Any) -> int:  # type: ignore[name-defined] # pyright: ignore[reportAny]
    """get the number of elements in an array

    supported inputs:
    - numpy arrays (`np.ndarray`) and numpy scalars (`np.generic`, always 1 element), via `.size`
    - torch tensors, including subclasses such as `torch.nn.Parameter`, via `.nelement()`

    torch is never imported here: a tensor is recognized by finding a class whose
    qualified name is `torch.Tensor` anywhere in `type(arr).__mro__`

    # Raises:
    - `TypeError` : if `arr` is none of the supported types
    """
    if isinstance(arr, (np.ndarray, np.generic)):
        return arr.size  # pyright: ignore[reportAny]

    # compare by qualified name so torch does not have to be importable; walking
    # the MRO (instead of comparing `str(type(arr))`) also accepts Tensor subclasses
    is_torch_tensor: bool = any(
        f"{cls.__module__}.{cls.__qualname__}" == "torch.Tensor"
        for cls in type(arr).__mro__
    )
    if is_torch_tensor:
        assert hasattr(arr, "nelement"), (
            "torch Tensor does not have nelement() method? this should not happen"
        )  # pyright: ignore[reportAny]
        return arr.nelement()  # pyright: ignore[reportAny]

    raise TypeError(f"invalid type: {type(arr)}")  # pyright: ignore[reportAny]

80 

81 

class ArrayMetadata(TypedDict):
    """Metadata for a numpy/torch array

    produced by `arr_metadata()`; `serialize_array()` copies these three keys into
    every `SerializedArrayWithMeta` it builds
    """

    # `list(arr.shape)`
    shape: list[int]
    # the dtype's `__name__` if it has one, otherwise `str(arr.dtype)`
    dtype: str
    # total element count, as computed by `array_n_elements()`
    n_elements: int

88 

89 

class SerializedArrayWithMeta(TypedDict):
    """Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)

    `serialize_array()` builds this for every mode except "list". The format tag is
    "<module>.<class name of the original array>:<mode>", and `infer_array_mode()`
    classifies the dict by the suffix after the colon.
    """

    # format tag; presumably this is the same key as `_FORMAT_KEY` -- TODO confirm
    __muutils_format__: str
    data: typing.Union[
        NumericList, str, int, float, bool
    ]  # list, hex str, b64 str, or scalar for zero_dim
    # same meaning as the corresponding keys of `ArrayMetadata`
    shape: list[int]
    dtype: str
    n_elements: int

100 

101 

def arr_metadata(arr: Any) -> ArrayMetadata:  # pyright: ignore[reportAny]
    """build the `ArrayMetadata` (shape, dtype name, element count) of an array

    `serialize_array()` calls this with a numpy array, or with the original object
    (which may be a torch tensor) when that object is zero-dimensional
    """
    shape: list[int] = list(arr.shape)  # pyright: ignore[reportAny]

    # prefer the dtype's `__name__`, otherwise fall back to its string form
    dtype_obj = arr.dtype  # pyright: ignore[reportAny]
    if hasattr(dtype_obj, "__name__"):
        dtype_name: str = dtype_obj.__name__  # pyright: ignore[reportAny]
    else:
        dtype_name = str(dtype_obj)

    return ArrayMetadata(
        shape=shape,
        dtype=dtype_name,
        n_elements=array_n_elements(arr),
    )

111 

112 

@overload
def serialize_array(
    jser: "JsonSerializer",
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],
    array_mode: Literal["list"],
) -> NumericList: ...
@overload
def serialize_array(
    jser: "JsonSerializer",
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],
    array_mode: ArrayModeWithMeta,
) -> SerializedArrayWithMeta: ...
@overload
def serialize_array(
    jser: "JsonSerializer",
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],
    array_mode: None = None,
) -> SerializedArrayWithMeta | NumericList: ...
def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],  # pyright: ignore[reportUnusedParameter]
    array_mode: ArrayMode | None = None,
) -> SerializedArrayWithMeta | NumericList:
    """serialize a numpy or pytorch array in one of several modes

    `array_mode: ArrayMode` can be one of:
    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`;
      for a zero-dimensional array this is a bare scalar)
    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`

    in every mode except `list`, a zero-dimensional array is always stored in the
    `zero_dim` format (with `arr.item()` under the key `data`), whatever mode was requested

    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
    ```
    {
        "__muutils_format__": "<module>.<class of arr>:<array_list_meta|array_hex_meta|array_b64_meta>",
        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
        "shape": list(arr.shape),
        "dtype": <dtype name, see `arr_metadata()`>,
        "n_elements": <number of elements>,
    }
    ```

    # Parameters:
    - `jser : JsonSerializer`
        only consulted for `jser.array_mode` when `array_mode` is `None`
    - `arr : np.ndarray | torch.Tensor`
        array to serialize; torch tensors are detached and moved to cpu before being
        converted to numpy, so tensors requiring grad or on another device work too
    - `path : str | Sequence[str | int]`
        not used by this function
    - `array_mode : ArrayMode | None`
        mode in which to serialize the array
        (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
    - `NumericList` for `list` mode, otherwise `SerializedArrayWithMeta`

    # Raises:
    - `KeyError` : if the array is not zero-dimensional and the mode is not one of
      `list`, `array_list_meta`, `array_hex_meta`, `array_b64_meta`
      (this includes `external` and `zero_dim`)
    """

    if array_mode is None:
        array_mode = jser.array_mode

    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"

    arr_np: np.ndarray
    if isinstance(arr, np.ndarray):  # pyright: ignore[reportUnnecessaryIsInstance]
        arr_np = arr
    elif any(
        f"{cls.__module__}.{cls.__qualname__}" == "torch.Tensor"
        for cls in type(arr).__mro__
    ):
        # `np.array(tensor)` raises for tensors that require grad or live on a
        # non-cpu device, so detach and move to cpu before converting
        arr_np = arr.detach().cpu().numpy()  # pyright: ignore[reportAny]
    else:
        arr_np = np.array(arr)

    # Handle list mode first (no metadata needed)
    if array_mode == "list":
        return arr_np.tolist()  # pyright: ignore[reportAny]

    # For all other modes, compute metadata once; zero-dim inputs take their
    # metadata from the original object rather than from the numpy copy
    metadata: ArrayMetadata = arr_metadata(arr if len(arr.shape) == 0 else arr_np)

    # TYPING: ty<=0.0.1a24 does not appear to support unpacking TypedDicts, so we do things manually. change it back later maybe?

    # zero-dimensional arrays are stored as their single item, regardless of array_mode
    if len(arr.shape) == 0:
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:zero_dim",
            data=arr.item(),  # pyright: ignore[reportAny]
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )

    # Handle the metadata modes
    if array_mode == "array_list_meta":
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:array_list_meta",
            data=arr_np.tolist(),  # pyright: ignore[reportAny]
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )
    elif array_mode == "array_hex_meta":
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:array_hex_meta",
            data=arr_np.tobytes().hex(),
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )
    elif array_mode == "array_b64_meta":
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:array_b64_meta",
            data=base64.b64encode(arr_np.tobytes()).decode(),
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )
    else:
        raise KeyError(f"invalid array_mode: {array_mode}")

225 

226 

@overload
def infer_array_mode(
    arr: SerializedArrayWithMeta,
) -> ArrayModeWithMeta: ...
@overload
def infer_array_mode(arr: NumericList) -> Literal["list"]: ...
def infer_array_mode(
    arr: Union[SerializedArrayWithMeta, NumericList],
) -> ArrayMode:
    """given a serialized array, infer the mode

    assumes the array was serialized via `serialize_array()`

    - a mapping is classified by the suffix of the string stored under `_FORMAT_KEY`
      (`":array_list_meta"`, `":array_hex_meta"`, `":array_b64_meta"`, `":external"`,
      `":zero_dim"`); for the list, hex, and b64 modes the type of `arr["data"]` is checked too
    - a plain `list` is mode `"list"`

    # Raises:
    - `ValueError` : if the format tag is missing, not a string, or unrecognized; if
      `data` has the wrong type for the detected mode; or if `arr` is neither a
      mapping nor a list
    """
    return_mode: ArrayMode
    if isinstance(arr, typing.Mapping):
        fmt_value = arr.get(_FORMAT_KEY, "")  # type: ignore
        # a non-string tag cannot be classified; report it like an unknown format
        # instead of failing with AttributeError on `.endswith`
        if not isinstance(fmt_value, str):
            raise ValueError(f"invalid format: {arr}")
        fmt: str = fmt_value
        if fmt.endswith(":array_list_meta"):
            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
            # str/bytes are Iterable as well, but never valid list data:
            # serialize_array stores zero-dim arrays as "zero_dim", so this is always a list
            if not isinstance(arr_data, Iterable) or isinstance(arr_data, (str, bytes)):
                raise ValueError(f"invalid list format: {type(arr_data) = }\t{arr}")
            return_mode = "array_list_meta"
        elif fmt.endswith(":array_hex_meta"):
            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
            if not isinstance(arr_data, str):
                raise ValueError(f"invalid hex format: {type(arr_data) = }\t{arr}")
            return_mode = "array_hex_meta"
        elif fmt.endswith(":array_b64_meta"):
            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
            if not isinstance(arr_data, str):
                raise ValueError(f"invalid b64 format: {type(arr_data) = }\t{arr}")
            return_mode = "array_b64_meta"
        elif fmt.endswith(":external"):
            return_mode = "external"
        elif fmt.endswith(":zero_dim"):
            return_mode = "zero_dim"
        else:
            raise ValueError(f"invalid format: {arr}")
    elif isinstance(arr, list):  # pyright: ignore[reportUnnecessaryIsInstance]
        return_mode = "list"
    else:
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")  # pyright: ignore[reportUnreachable]

    return return_mode

271 

272 

273@overload 

274def load_array( 

275 arr: SerializedArrayWithMeta, 

276 array_mode: Optional[ArrayModeWithMeta] = None, 

277) -> np.ndarray: ... 

278@overload 

279def load_array( 

280 arr: NumericList, 

281 array_mode: Optional[Literal["list"]] = None, 

282) -> np.ndarray: ... 

283@overload 

284def load_array( 

285 arr: np.ndarray, 

286 array_mode: None = None, 

287) -> np.ndarray: ... 

288def load_array( 

289 arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList], 

290 array_mode: Optional[ArrayMode] = None, 

291) -> np.ndarray: 

292 """load a json-serialized array, infer the mode if not specified""" 

293 # return arr if its already a numpy array 

294 if isinstance(arr, np.ndarray): 

295 assert array_mode is None, ( 

296 "array_mode should not be specified when loading a numpy array, since that is a no-op" 

297 ) 

298 return arr 

299 

300 # try to infer the array_mode 

301 array_mode_inferred: ArrayMode = infer_array_mode(arr) 

302 if array_mode is None: 

303 array_mode = array_mode_inferred 

304 elif array_mode != array_mode_inferred: 

305 warnings.warn( 

306 f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}" 

307 ) 

308 

309 # actually load the array 

310 if array_mode == "array_list_meta": 

311 assert isinstance(arr, typing.Mapping), ( 

312 f"invalid list format: {type(arr) = }\n{arr = }" 

313 ) 

314 data = np.array(arr["data"], dtype=arr["dtype"]) # type: ignore 

315 if tuple(arr["shape"]) != tuple(data.shape): # type: ignore 

316 raise ValueError(f"invalid shape: {arr}") 

317 return data 

318 

319 elif array_mode == "array_hex_meta": 

320 assert isinstance(arr, typing.Mapping), ( 

321 f"invalid list format: {type(arr) = }\n{arr = }" 

322 ) 

323 data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"]) # type: ignore 

324 return data.reshape(arr["shape"]) # type: ignore 

325 

326 elif array_mode == "array_b64_meta": 

327 assert isinstance(arr, typing.Mapping), ( 

328 f"invalid list format: {type(arr) = }\n{arr = }" 

329 ) 

330 data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"]) # type: ignore 

331 return data.reshape(arr["shape"]) # type: ignore 

332 

333 elif array_mode == "list": 

334 assert isinstance(arr, typing.Sequence), ( 

335 f"invalid list format: {type(arr) = }\n{arr = }" 

336 ) 

337 return np.array(arr) # type: ignore 

338 elif array_mode == "external": 

339 assert isinstance(arr, typing.Mapping) 

340 if "data" not in arr: 

341 raise KeyError( # pyright: ignore[reportUnreachable] 

342 f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}" 

343 ) 

344 # we can ignore here since we assume ZANJ has taken care of it 

345 return arr["data"] # type: ignore[return-value] # pyright: ignore[reportReturnType] 

346 elif array_mode == "zero_dim": 

347 assert isinstance(arr, typing.Mapping) 

348 data = np.array(arr["data"]) # ty: ignore[invalid-argument-type] 

349 if tuple(arr["shape"]) != tuple(data.shape): # type: ignore 

350 raise ValueError(f"invalid shape: {arr}") 

351 return data 

352 else: 

353 raise ValueError(f"invalid array_mode: {array_mode}") # pyright: ignore[reportUnreachable]