Coverage for muutils/json_serialize/array.py: 62%
95 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""this utilities module handles serialization and loading of numpy and torch arrays as json
3- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
4- `array_b64_meta` is the most efficient, but is not human readable.
5- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ)
7"""
9from __future__ import annotations
11import base64
12import typing
13import warnings
14from typing import Any, Iterable, Literal, Optional, Sequence
16try:
17 import numpy as np
18except ImportError as e:
19 warnings.warn(
20 f"numpy is not installed, array serialization will not work: \n{e}",
21 ImportWarning,
22 )
24from muutils.json_serialize.util import _FORMAT_KEY, JSONitem
26# pylint: disable=unused-argument
# Tags naming the JSON representations of an array handled by this module.
ArrayMode = Literal[
    "list",  # plain nested lists, no metadata
    "array_list_meta",  # metadata dict, nested lists under "data"
    "array_hex_meta",  # metadata dict, hex string under "data"
    "array_b64_meta",  # metadata dict, base64 string under "data"
    "external",  # payload stored under "data" by an outside tool (ZANJ)
    "zero_dim",  # zero-dimensional array, single item under "data"
]
def array_n_elements(arr) -> int:  # type: ignore[name-defined]
    """get the number of elements in an array

    works on numpy arrays and on torch tensors. torch is never imported here:
    a tensor is recognized by the qualified name of its class, checked along the
    whole MRO so that subclasses of `torch.Tensor` (for example
    `torch.nn.Parameter`) are accepted as well.

    # Parameters:
     - `arr`
       numpy array or torch tensor

    # Returns:
     - `int`
       `arr.size` for a numpy array, `arr.nelement()` for a torch tensor

    # Raises:
     - `TypeError` : if `arr` is neither a numpy array nor a torch tensor
    """
    if isinstance(arr, np.ndarray):
        return arr.size

    # name-based check keeps torch an optional dependency
    is_torch_tensor: bool = any(
        f"{cls.__module__}.{cls.__qualname__}" == "torch.Tensor"
        for cls in type(arr).__mro__
    )
    if is_torch_tensor:
        return arr.nelement()

    raise TypeError(f"invalid type: {type(arr)}")
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """describe an array by its shape, dtype name, and element count

    # Parameters:
     - `arr`
       array-like object exposing `.shape` and `.dtype`; it must also be
       accepted by `array_n_elements`

    # Returns:
     - `dict` with the keys
       - `"shape"`: `arr.shape` converted to a list
       - `"dtype"`: `arr.dtype.__name__` when that attribute exists, else `str(arr.dtype)`
       - `"n_elements"`: result of `array_n_elements(arr)`
    """
    shape = list(arr.shape)

    dtype_obj = arr.dtype
    if hasattr(dtype_obj, "__name__"):
        dtype_name = dtype_obj.__name__
    else:
        dtype_name = str(dtype_obj)

    return dict(
        shape=shape,
        dtype=dtype_name,
        n_elements=array_n_elements(arr),
    )
def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: np.ndarray,
    path: str | Sequence[str | int],
    array_mode: ArrayMode | None = None,
) -> JSONitem:
    """serialize a numpy or pytorch array in one of several modes

    a zero-dimensional array ignores `array_mode`: its single item is stored
    under `data` with the format suffix `zero_dim`.

    `array_mode: ArrayMode` can be one of:
    - `list`: plain nested lists, no metadata (equivalent to `arr.tolist()`)
    - `array_list_meta`: dict with metadata, nested lists under the key `data`
    - `array_hex_meta`: dict with metadata, hex string under the key `data`
    - `array_b64_meta`: dict with metadata, base64 string under the key `data`

    for `zero_dim`, `array_list_meta`, `array_hex_meta`, and `array_b64_meta`,
    the serialized object is:
    ```
    {
        _FORMAT_KEY: "<module>.<type name of arr>:<mode>",
        "data": <arr.item()|arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
        **arr_metadata(...),  # "shape", "dtype", "n_elements"
    }
    ```

    # Parameters:
     - `jser : JsonSerializer`
       only read for `jser.array_mode`, and only when `array_mode` is `None`
     - `arr : Any` array to serialize
       anything that is not a `np.ndarray` is converted with `np.array(arr)`
     - `path : str | Sequence[str | int]`
       not used by this function
     - `array_mode : ArrayMode` mode in which to serialize the array
       (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
     - `JSONitem`
       json serialized array

    # Raises:
     - `KeyError` : if the array mode is not valid
    """
    mode = jser.array_mode if array_mode is None else array_mode

    kind = type(arr)
    type_tag: str = f"{kind.__module__}.{kind.__name__}"
    as_numpy: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr)

    # zero-dimensional arrays: metadata is taken from the object as passed in
    if len(arr.shape) == 0:
        return {
            _FORMAT_KEY: f"{type_tag}:zero_dim",
            "data": arr.item(),
            **arr_metadata(arr),
        }

    # the only mode that does not produce a metadata dict
    if mode == "list":
        return as_numpy.tolist()

    if mode == "array_list_meta":
        suffix, payload = "array_list_meta", as_numpy.tolist()
    elif mode == "array_hex_meta":
        suffix, payload = "array_hex_meta", as_numpy.tobytes().hex()
    elif mode == "array_b64_meta":
        suffix, payload = (
            "array_b64_meta",
            base64.b64encode(as_numpy.tobytes()).decode(),
        )
    else:
        raise KeyError(f"invalid array_mode: {mode}")

    return {
        _FORMAT_KEY: f"{type_tag}:{suffix}",
        "data": payload,
        **arr_metadata(as_numpy),
    }
def infer_array_mode(arr: JSONitem) -> ArrayMode:
    """given a serialized array, infer the mode

    assumes the array was serialized via `serialize_array()`

    a mapping is classified by the suffix of the string stored under
    `_FORMAT_KEY`; a `list` is classified as `"list"`.

    # Parameters:
     - `arr : JSONitem`
       serialized array (mapping or list)

    # Returns:
     - `ArrayMode`
       the inferred mode

    # Raises:
     - `ValueError` : if the format value is missing, not a string, or has an
       unknown suffix; if `data` has the wrong type for the mode; or if `arr`
       is neither a mapping nor a list
     - `KeyError` : if a list/hex/b64 mapping has no `data` key
    """
    if isinstance(arr, typing.Mapping):
        fmt = arr.get(_FORMAT_KEY, "")
        # a non-string format would otherwise fail on `.endswith` with an
        # AttributeError; report it like any other malformed format
        if not isinstance(fmt, str):
            raise ValueError(f"invalid format: {arr}")

        if fmt.endswith(":array_list_meta"):
            if not isinstance(arr["data"], Iterable):
                raise ValueError(f"invalid list format: {type(arr['data']) = }\t{arr}")
            return "array_list_meta"
        elif fmt.endswith(":array_hex_meta"):
            if not isinstance(arr["data"], str):
                raise ValueError(f"invalid hex format: {type(arr['data']) = }\t{arr}")
            return "array_hex_meta"
        elif fmt.endswith(":array_b64_meta"):
            if not isinstance(arr["data"], str):
                raise ValueError(f"invalid b64 format: {type(arr['data']) = }\t{arr}")
            return "array_b64_meta"
        elif fmt.endswith(":external"):
            return "external"
        elif fmt.endswith(":zero_dim"):
            return "zero_dim"
        else:
            raise ValueError(f"invalid format: {arr}")
    elif isinstance(arr, list):
        return "list"
    else:
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified

    # Parameters:
     - `arr : JSONitem`
       serialized array; a `np.ndarray` is returned unchanged when
       `array_mode` is `None`
     - `array_mode : Optional[ArrayMode]`
       mode to load with. the mode is always inferred from `arr` as well; if
       the two disagree a warning is emitted and the given mode is used
       (defaults to `None`, meaning the inferred mode is used)

    # Returns:
     - `Any`
       a numpy array for every mode except `external`, which returns
       `arr["data"]` as stored

    # Raises:
     - `ValueError` : if the mode cannot be inferred, the stored shape does not
       match the loaded data (`array_list_meta`, `zero_dim`), or the mode is unknown
     - `KeyError` : if an `external` array has no `data` key
     - `AssertionError` : if `arr` has the wrong container type for the mode
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid hex format: {type(arr) = }\n{arr = }"
        # NOTE: np.frombuffer over a bytes object yields a read-only array
        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid b64 format: {type(arr) = }\n{arr = }"
        # NOTE: np.frombuffer over a bytes object yields a read-only array
        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        return np.array(arr)  # type: ignore

    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]

    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        # restore the stored dtype so that e.g. a float32 scalar does not come
        # back as float64; fall back to numpy's own inference when the stored
        # name is not a numpy dtype (e.g. "torch.float32") or does not fit the data
        try:
            data = np.array(arr["data"], dtype=arr.get("dtype"))  # type: ignore
        except (TypeError, ValueError):
            data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    else:
        raise ValueError(f"invalid array_mode: {array_mode}")