Coverage for muutils / json_serialize / array.py: 60%
114 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
1"""this utilities module handles serialization and loading of numpy and torch arrays as json
3- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
4- `array_b64_meta` is the most efficient, but is not human readable.
5- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ)
7"""
9from __future__ import annotations
11import base64
12import typing
13import warnings
14from typing import (
15 TYPE_CHECKING,
16 Any,
17 Iterable,
18 Literal,
19 Optional,
20 Sequence,
21 TypedDict,
22 Union,
23 overload,
24)
26try:
27 import numpy as np
28except ImportError as e:
29 warnings.warn(
30 f"numpy is not installed, array serialization will not work: \n{e}",
31 ImportWarning,
32 )
34if TYPE_CHECKING:
35 import numpy as np
36 import torch
37 from muutils.json_serialize.json_serialize import JsonSerializer
39from muutils.json_serialize.types import _FORMAT_KEY # pyright: ignore[reportPrivateUsage]
41# TYPING: pyright complains way too much here
42# pyright: reportCallIssue=false,reportArgumentType=false,reportUnknownVariableType=false,reportUnknownMemberType=false
# Recursive type for nested numeric lists (output of arr.tolist())
# either a flat list of python scalars (int / float / bool), or a list whose items are
# themselves `NumericList`s
NumericList = typing.Union[
    typing.List[typing.Union[int, float, bool]],
    typing.List["NumericList"],
]
# every array mode name known to this module: `load_array()` has a branch for each of these,
# and `infer_array_mode()` can return each of these
ArrayMode = Literal[
    "list",
    "array_list_meta",
    "array_hex_meta",
    "array_b64_meta",
    "external",
    "zero_dim",
]
# Modes that produce SerializedArrayWithMeta (dict with metadata)
# i.e. every member of `ArrayMode` except "list", which is a bare nested list
ArrayModeWithMeta = Literal[
    "array_list_meta",
    "array_hex_meta",
    "array_b64_meta",
    "zero_dim",
    "external",
]
def array_n_elements(arr: Any) -> int:  # type: ignore[name-defined] # pyright: ignore[reportAny]
    """get the number of elements in an array

    # Parameters:
     - `arr : Any`
       a `numpy.ndarray` or a `torch.Tensor` (instances of subclasses of either are accepted)

    # Returns:
     - `int`
       `arr.size` for numpy arrays, `arr.nelement()` for torch tensors

    # Raises:
     - `TypeError` : if `arr` is neither a numpy array nor a torch tensor
    """
    if isinstance(arr, np.ndarray):
        return arr.size

    # torch is only imported under TYPE_CHECKING in this module, so tensors are recognized by
    # their qualified class name rather than with `isinstance`. Walking the MRO, instead of
    # comparing only the exact type, also accepts subclasses of `torch.Tensor`
    # (e.g. `torch.nn.Parameter`), which an exact-type comparison rejects.
    is_torch_tensor: bool = any(
        f"{getattr(cls, '__module__', '')}.{cls.__qualname__}" == "torch.Tensor"
        for cls in type(arr).__mro__
    )
    if is_torch_tensor:
        assert hasattr(arr, "nelement"), (
            "torch Tensor does not have nelement() method? this should not happen"
        )  # pyright: ignore[reportAny]
        return arr.nelement()  # pyright: ignore[reportAny]

    raise TypeError(f"invalid type: {type(arr)}")  # pyright: ignore[reportAny]
class ArrayMetadata(TypedDict):
    """Metadata for a numpy/torch array

    this is the mapping built by `arr_metadata()`
    """

    # `list(arr.shape)`; an empty list for a zero-dimensional array
    shape: list[int]
    # `arr.dtype.__name__` when the dtype has that attribute, otherwise `str(arr.dtype)`
    dtype: str
    # total element count, as returned by `array_n_elements()`
    n_elements: int
class SerializedArrayWithMeta(TypedDict):
    """Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)"""

    # format tag; `serialize_array()` writes it as "<module>.<class name of the array>:<mode>",
    # and `infer_array_mode()` recovers the mode from the ":<mode>" suffix
    __muutils_format__: str
    data: typing.Union[
        NumericList, str, int, float, bool
    ]  # list, hex str, b64 str, or scalar for zero_dim
    # the three fields below hold the same values as the `ArrayMetadata` fields of the same name
    shape: list[int]
    dtype: str
    n_elements: int
def arr_metadata(arr: Any) -> ArrayMetadata:  # pyright: ignore[reportAny]
    """collect the shape, dtype name and element count of an array

    # Parameters:
     - `arr : Any`
       object exposing `.shape` and `.dtype`, and accepted by `array_n_elements()`

    # Returns:
     - `ArrayMetadata`
       dict with the keys `shape`, `dtype` and `n_elements`
    """
    dims: list[int] = list(arr.shape)  # pyright: ignore[reportAny]

    # prefer the dtype's `__name__` when it has one, otherwise fall back to its string form
    dtype_obj = arr.dtype  # pyright: ignore[reportAny]
    if hasattr(dtype_obj, "__name__"):
        dtype_label: str = dtype_obj.__name__
    else:
        dtype_label = str(dtype_obj)

    element_count: int = array_n_elements(arr)
    return {
        "shape": dims,
        "dtype": dtype_label,
        "n_elements": element_count,
    }
@overload
def serialize_array(
    jser: "JsonSerializer",
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],
    array_mode: Literal["list"],
) -> NumericList: ...
@overload
def serialize_array(
    jser: "JsonSerializer",
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],
    array_mode: ArrayModeWithMeta,
) -> SerializedArrayWithMeta: ...
@overload
def serialize_array(
    jser: "JsonSerializer",
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],
    array_mode: None = None,
) -> SerializedArrayWithMeta | NumericList: ...
def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],  # pyright: ignore[reportUnusedParameter]
    array_mode: ArrayMode | None = None,
) -> SerializedArrayWithMeta | NumericList:
    """serialize a numpy or pytorch array in one of several modes

    `list` mode is checked first and returns the bare nested list. for every other mode, a
    zero-dimensional array is serialized as its single item under the `zero_dim` format tag,
    regardless of which mode was requested.

    `array_mode: ArrayMode` can be one of:
    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`

    for `array_list_meta`, `array_hex_meta`, `array_b64_meta` and zero-dimensional arrays,
    the serialized object is:
    ```
    {
        "__muutils_format__": "<module>.<class name of arr>:<array_list_meta|array_hex_meta|array_b64_meta|zero_dim>",
        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()|arr.item()>,
        "shape": list(arr.shape),
        "dtype": <dtype name, see `arr_metadata()`>,
        "n_elements": <number of elements>,
    }
    ```

    # Parameters:
     - `jser : JsonSerializer`
       only read for its `array_mode` attribute, and only when `array_mode` is `None`
     - `arr : Any` array to serialize
     - `path`
       not used by this function
     - `array_mode : ArrayMode` mode in which to serialize the array
       (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
     - `JSONitem`
       json serialized array

    # Raises:
     - `KeyError` : if the array is not zero-dimensional and the mode is not one of
       `list`, `array_list_meta`, `array_hex_meta`, `array_b64_meta`
    """
    mode = jser.array_mode if array_mode is None else array_mode

    type_tag: str = f"{type(arr).__module__}.{type(arr).__name__}"
    # anything that is not already an ndarray goes through `np.array` to get a numpy copy
    as_numpy: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr)  # pyright: ignore[reportUnnecessaryIsInstance]

    # list mode carries no metadata, so it is finished here
    if mode == "list":
        return as_numpy.tolist()  # pyright: ignore[reportAny]

    # zero-dimensional input is described from the original object, everything else from
    # the numpy conversion
    is_zero_dim: bool = len(arr.shape) == 0
    meta: ArrayMetadata = arr_metadata(arr if is_zero_dim else as_numpy)

    # pick the format suffix and the payload; only these differ between the branches
    if is_zero_dim:
        mode_tag = "zero_dim"
        payload = arr.item()  # pyright: ignore[reportAny]
    elif mode == "array_list_meta":
        mode_tag = mode
        payload = as_numpy.tolist()  # pyright: ignore[reportAny]
    elif mode == "array_hex_meta":
        mode_tag = mode
        payload = as_numpy.tobytes().hex()
    elif mode == "array_b64_meta":
        mode_tag = mode
        payload = base64.b64encode(as_numpy.tobytes()).decode()
    else:
        raise KeyError(f"invalid array_mode: {mode}")

    # TYPING: the TypedDict is built field by field rather than by unpacking `meta`
    return SerializedArrayWithMeta(
        __muutils_format__=f"{type_tag}:{mode_tag}",
        data=payload,
        shape=meta["shape"],
        dtype=meta["dtype"],
        n_elements=meta["n_elements"],
    )
@overload
def infer_array_mode(
    arr: SerializedArrayWithMeta,
) -> ArrayModeWithMeta: ...
@overload
def infer_array_mode(arr: NumericList) -> Literal["list"]: ...
def infer_array_mode(
    arr: Union[SerializedArrayWithMeta, NumericList],
) -> ArrayMode:
    """given a serialized array, infer the mode

    assumes the array was serialized via `serialize_array()`

    a mapping is classified by the `:<mode>` suffix of its format tag (the value stored under
    `_FORMAT_KEY`); a plain `list` is classified as `"list"`.

    # Raises:
     - `ValueError` : if the format tag has no known suffix, if the `data` entry has the wrong
       type for the tagged mode, or if `arr` is neither a mapping nor a list
     - `KeyError` : if the tag names a list/hex/b64 mode but the mapping has no `data` entry
    """
    # anything that is not a mapping is either a bare list or unsupported
    if not isinstance(arr, typing.Mapping):
        if isinstance(arr, list):  # pyright: ignore[reportUnnecessaryIsInstance]
            return "list"
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")  # pyright: ignore[reportUnreachable]

    # _FORMAT_KEY always maps to a string
    fmt: str = arr.get(_FORMAT_KEY, "")  # type: ignore

    # modes whose payload type is validated: (mode, label used in the error text, required type)
    payload_checks = (
        ("array_list_meta", "list", Iterable),
        ("array_hex_meta", "hex", str),
        ("array_b64_meta", "b64", str),
    )
    for mode, label, required_type in payload_checks:
        if fmt.endswith(":" + mode):
            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
            if not isinstance(arr_data, required_type):
                raise ValueError(f"invalid {label} format: {type(arr_data) = }\t{arr}")
            return mode  # type: ignore

    # modes that are accepted on the strength of the tag alone
    for mode in ("external", "zero_dim"):
        if fmt.endswith(":" + mode):
            return mode  # type: ignore

    raise ValueError(f"invalid format: {arr}")
@overload
def load_array(
    arr: SerializedArrayWithMeta,
    array_mode: Optional[ArrayModeWithMeta] = None,
) -> np.ndarray: ...
@overload
def load_array(
    arr: NumericList,
    array_mode: Optional[Literal["list"]] = None,
) -> np.ndarray: ...
@overload
def load_array(
    arr: np.ndarray,
    array_mode: None = None,
) -> np.ndarray: ...
def load_array(
    arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList],
    array_mode: Optional[ArrayMode] = None,
) -> np.ndarray:
    """load a json-serialized array, infer the mode if not specified

    # Parameters:
     - `arr`
       a serialized array (mapping with metadata, or a bare nested list), or an
       already-loaded `np.ndarray`, which is returned unchanged
     - `array_mode : Optional[ArrayMode]`
       mode to decode with. when `None`, the mode from `infer_array_mode(arr)` is used.
       when given and different from the inferred mode, a warning is emitted and the
       given mode is used anyway.

    # Returns:
     - `np.ndarray`
       the loaded array. for `external` mode, whatever is stored under `arr["data"]` is
       returned as-is.

    # Raises:
     - `ValueError` : if the loaded data does not have the recorded shape
       (`array_list_meta`, `zero_dim`), or if the mode is not a known `ArrayMode`
     - `KeyError` : in `external` mode, if the mapping has no `data` entry
     - `AssertionError` : if `arr` does not have the container type its mode requires, or
       if `array_mode` is given together with an `np.ndarray`

    NOTE: in `array_hex_meta` and `array_b64_meta` modes the array is created by
    `np.frombuffer` over an immutable `bytes` object, which numpy documents as yielding a
    read-only array -- copy it before writing to it.
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray):
        assert array_mode is None, (
            "array_mode should not be specified when loading a numpy array, since that is a no-op"
        )
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid hex format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid b64 format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "list":
        assert isinstance(arr, typing.Sequence), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        return np.array(arr)  # type: ignore
    elif array_mode == "external":
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(  # pyright: ignore[reportUnreachable]
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        # we can ignore here since we assume ZANJ has taken care of it
        return arr["data"]  # type: ignore[return-value] # pyright: ignore[reportReturnType]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        scalar = arr["data"]  # ty: ignore[invalid-argument-type]
        # honour the recorded dtype, like the other metadata modes do, so that e.g. a float32
        # scalar is not silently widened to float64. the recorded name may be one numpy cannot
        # parse (`serialize_array()` records a zero-dim tensor's own dtype string), or may not
        # fit the stored value; in either case let numpy infer the dtype from the scalar.
        dtype_name = arr.get("dtype")
        data = None
        if isinstance(dtype_name, str):
            try:
                data = np.array(scalar, dtype=np.dtype(dtype_name))
            except (TypeError, ValueError, OverflowError):
                data = None
        if data is None:
            data = np.array(scalar)
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")  # pyright: ignore[reportUnreachable]