Coverage for muutils/json_serialize/array.py: 62%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""this utilities module handles serialization and loading of numpy and torch arrays as json 

2 

3- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability. 

4- `array_b64_meta` is the most efficient, but is not human readable. 

5- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ) 

6 

7""" 

8 

9from __future__ import annotations 

10 

11import base64 

12import typing 

13import warnings 

14from typing import Any, Iterable, Literal, Optional, Sequence 

15 

16try: 

17 import numpy as np 

18except ImportError as e: 

19 warnings.warn( 

20 f"numpy is not installed, array serialization will not work: \n{e}", 

21 ImportWarning, 

22 ) 

23 

24from muutils.json_serialize.util import _FORMAT_KEY, JSONitem 

25 

26# pylint: disable=unused-argument 

27 

# every array format this module knows about. `serialize_array` accepts only the
# first four as an `array_mode` argument; `external` and `zero_dim` are recognized
# when inferring/loading (`zero_dim` is what `serialize_array` emits for 0-d arrays)
ArrayMode = Literal[
    "list",  # nested list of values, no metadata
    "array_list_meta",  # nested list under "data", plus metadata
    "array_hex_meta",  # raw bytes as a hex string under "data", plus metadata
    "array_b64_meta",  # raw bytes as base64 text under "data", plus metadata
    "external",  # written elsewhere (see ZANJ); `load_array` returns "data" as-is
    "zero_dim",  # the single item of a zero-dimensional array, plus metadata
]

36 

37 

def array_n_elements(arr: Any) -> int:
    """get the number of elements in an array

    # Parameters:
     - `arr : Any`
       a `np.ndarray`, a numpy scalar (`np.generic`), or a `torch.Tensor`
       (including subclasses such as `torch.nn.Parameter`)

    # Returns:
     - `int`
       total number of elements

    # Raises:
     - `TypeError` : if `arr` is none of the supported types
    """
    if isinstance(arr, (np.ndarray, np.generic)):
        return arr.size
    # identify torch tensors (and subclasses) by name via the MRO, so that torch
    # does not need to be imported; an exact `str(type(arr))` comparison would
    # reject subclasses like `torch.nn.Parameter`
    if any(
        getattr(cls, "__module__", None) == "torch" and cls.__name__ == "Tensor"
        for cls in type(arr).__mro__
    ):
        return arr.nelement()
    raise TypeError(f"invalid type: {type(arr)}")

46 

47 

def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """build the metadata dict stored next to a serialized array

    # Parameters:
     - `arr : Any`
       numpy array or torch tensor (must be accepted by `array_n_elements`)

    # Returns:
     - `dict[str, list[int] | str | int]`
       keys, in order: `shape` (list of ints), `dtype` (the dtype's `__name__`
       if it has one, otherwise `str(arr.dtype)`), and `n_elements`
    """
    shape_as_list: list[int] = [*arr.shape]
    dtype_obj = arr.dtype
    dtype_label: str = (
        str(dtype_obj) if not hasattr(dtype_obj, "__name__") else dtype_obj.__name__
    )
    return dict(
        shape=shape_as_list,
        dtype=dtype_label,
        n_elements=array_n_elements(arr),
    )

57 

58 

def _array_to_numpy(arr: Any) -> np.ndarray:
    """convert `arr` to a `np.ndarray`, returning numpy arrays unchanged

    objects exposing both `detach()` and `cpu()` (e.g. torch tensors) are
    detached and moved to the cpu first, because `np.array(tensor)` fails for
    tensors that require grad or live on a non-cpu device
    """
    if isinstance(arr, np.ndarray):
        return arr
    if hasattr(arr, "detach") and hasattr(arr, "cpu"):
        arr = arr.detach().cpu()
    return np.array(arr)


def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: np.ndarray,
    path: str | Sequence[str | int],
    array_mode: ArrayMode | None = None,
) -> JSONitem:
    """serialize a numpy array or torch tensor in one of several modes

    zero-dimensional inputs are always stored as
    `{_FORMAT_KEY: "<type>:zero_dim", "data": arr.item(), **arr_metadata(arr)}`,
    whatever `array_mode` is

    `array_mode: ArrayMode` can be one of:
    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`

    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
    ```
    {
        _FORMAT_KEY: "<module>.<type name>:<array_mode>",
        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
        "shape": list(arr.shape),
        "dtype": <dtype as a string, e.g. "float64">,
        "n_elements": <total number of elements>,
    }
    ```
    where `<module>.<type name>` is the type of the input (e.g. `numpy.ndarray`)
    and the metadata describes the numpy conversion of the input

    # Parameters:
     - `jser : JsonSerializer`
       serializer whose `array_mode` attribute is used when `array_mode` is `None`
     - `arr : Any`
       array to serialize (numpy array or torch tensor); tensors are detached
       and moved to the cpu before conversion to numpy
     - `path : str | Sequence[str | int]`
       not used by this function
     - `array_mode : ArrayMode | None`
       mode in which to serialize the array
       (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
     - `JSONitem`
       json serialized array

    # Raises:
     - `KeyError` : if the array mode is not one of the four modes above
       (`external` and `zero_dim` are not accepted here)
    """
    if array_mode is None:
        array_mode = jser.array_mode

    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"

    # handle zero-dimensional arrays first: they only need `.item()`, so no numpy
    # conversion is done (converting eagerly could fail, e.g. for grad tensors)
    if len(arr.shape) == 0:
        return {
            _FORMAT_KEY: f"{arr_type}:zero_dim",
            "data": arr.item(),
            **arr_metadata(arr),
        }

    arr_np: np.ndarray = _array_to_numpy(arr)

    if array_mode == "array_list_meta":
        return {
            _FORMAT_KEY: f"{arr_type}:array_list_meta",
            "data": arr_np.tolist(),
            **arr_metadata(arr_np),
        }
    elif array_mode == "list":
        return arr_np.tolist()
    elif array_mode == "array_hex_meta":
        return {
            _FORMAT_KEY: f"{arr_type}:array_hex_meta",
            "data": arr_np.tobytes().hex(),
            **arr_metadata(arr_np),
        }
    elif array_mode == "array_b64_meta":
        return {
            _FORMAT_KEY: f"{arr_type}:array_b64_meta",
            "data": base64.b64encode(arr_np.tobytes()).decode(),
            **arr_metadata(arr_np),
        }
    else:
        raise KeyError(f"invalid array_mode: {array_mode}")

134 

135 

def infer_array_mode(arr: JSONitem) -> ArrayMode:
    """work out which `ArrayMode` produced a serialized array

    expects the output of `serialize_array()`: a plain `list` means `list` mode,
    and a mapping is classified by its `_FORMAT_KEY` value ending in `:<mode>`.
    for the list/hex/b64 metadata modes the type of `arr["data"]` is also checked.

    # Raises:
     - `ValueError` : if `arr` is neither a mapping nor a list, if the format
       suffix is not recognized, or if `data` has the wrong type for its mode
    """
    if not isinstance(arr, typing.Mapping):
        if isinstance(arr, list):
            return "list"
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")

    fmt: str = arr.get(_FORMAT_KEY, "")  # type: ignore
    known_modes: tuple[ArrayMode, ...] = (
        "array_list_meta",
        "array_hex_meta",
        "array_b64_meta",
        "external",
        "zero_dim",
    )
    mode: Optional[ArrayMode] = next(
        (candidate for candidate in known_modes if fmt.endswith(f":{candidate}")),
        None,
    )
    if mode is None:
        raise ValueError(f"invalid format: {arr}")

    # modes whose "data" payload has a required type: mode -> (label, expected type)
    payload_checks = {
        "array_list_meta": ("list", Iterable),
        "array_hex_meta": ("hex", str),
        "array_b64_meta": ("b64", str),
    }
    if mode in payload_checks:
        label, expected = payload_checks[mode]
        if not isinstance(arr["data"], expected):
            raise ValueError(f"invalid {label} format: {type(arr['data']) = }\t{arr}")
    return mode

165 

166 

def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, inferring the mode if not specified

    # Parameters:
     - `arr : JSONitem`
       output of `serialize_array()`; a `np.ndarray` is returned unchanged when
       `array_mode` is `None`
     - `array_mode : Optional[ArrayMode]`
       mode to load with. when given it takes precedence over the inferred mode
       (a warning is emitted on mismatch); if the mode cannot be inferred at all,
       the given mode is used without a warning
       (defaults to `None`, meaning infer it)

    # Returns:
     - `Any`
       a `np.ndarray`, except in `external` mode, where `arr["data"]` is
       returned as-is

    # Raises:
     - `ValueError` : if no mode is given and it cannot be inferred, if the
       stored shape does not match the data, or if `array_mode` is unknown
     - `KeyError` : if an `external` array has no `data` key (or other
       expected keys are missing)
     - `AssertionError` : if `arr` has the wrong container type for the mode
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    if array_mode is None:
        array_mode = infer_array_mode(arr)
    else:
        # an explicit mode wins; inference only serves to warn about mismatches,
        # so failing to infer (e.g. a tuple passed for `list` mode) must not abort
        try:
            array_mode_inferred: Optional[ArrayMode] = infer_array_mode(arr)
        except ValueError:
            array_mode_inferred = None
        if array_mode_inferred is not None and array_mode != array_mode_inferred:
            warnings.warn(
                f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
            )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid hex format: {type(arr) = }\n{arr = }"
        # decode into a mutable bytearray so the resulting array is writeable like
        # in the other modes (`np.frombuffer` over immutable `bytes` is read-only)
        data = np.frombuffer(bytearray.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid b64 format: {type(arr) = }\n{arr = }"
        # bytearray for the same writeability reason as in the hex branch
        data = np.frombuffer(
            bytearray(base64.b64decode(arr["data"])),  # type: ignore
            dtype=arr["dtype"],  # type: ignore
        )
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        return np.array(arr)  # type: ignore
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])
        # restore the recorded dtype (e.g. "float32", "int8") when numpy
        # understands it; strings it cannot parse (such as torch dtype names)
        # deliberately keep numpy's inferred dtype
        dtype_str = arr.get("dtype")
        if isinstance(dtype_str, str):
            try:
                data = data.astype(np.dtype(dtype_str))
            except (TypeError, ValueError):
                pass
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")