Coverage for tests / unit / json_serialize / test_array.py: 100%

125 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1import numpy as np 

2import pytest 

3 

4from muutils.json_serialize import JsonSerializer 

5from muutils.json_serialize.array import ( 

6 ArrayMode, 

7 ArrayModeWithMeta, 

8 arr_metadata, 

9 array_n_elements, 

10 load_array, 

11 serialize_array, 

12) 

13from muutils.json_serialize.types import _FORMAT_KEY 

14 

15# pylint: disable=missing-class-docstring 

16 

17 

class TestArray:
    """Round-trip and metadata tests for ``muutils.json_serialize.array``."""

    def setup_method(self):
        # pytest calls this before every test method, so each test gets fresh
        # fixtures: 1d/2d/3d arrays, a 0-d array, and a "list"-mode serializer.
        self.array_1d = np.array([1, 2, 3])
        self.array_2d = np.array([[1, 2], [3, 4]])
        self.array_3d = np.array(
            [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
            dtype=np.int64,
        )
        self.array_zero_dim = np.array(42)  # 0-d array, shape ()
        self.jser = JsonSerializer(array_mode="list")

    def test_array_n_elements(self):
        cases = [
            (self.array_1d, 3),
            (self.array_2d, 4),
            (self.array_3d, 8),
            (self.array_zero_dim, 1),  # a 0-d array still holds one element
        ]
        for arr, n_expected in cases:
            assert array_n_elements(arr) == n_expected

    def test_arr_metadata(self):
        meta = arr_metadata(self.array_3d)
        expected = {"shape": [2, 2, 2], "dtype": "int64", "n_elements": 8}
        for key, value in expected.items():
            assert meta[key] == value

    @pytest.mark.parametrize(
        "array_mode,expected_type",
        [
            ("list", list),
            ("array_list_meta", dict),
            ("array_hex_meta", dict),
            ("array_b64_meta", dict),
        ],
    )
    def test_serialize_array(self, array_mode: ArrayMode, expected_type: type):
        # "list" mode yields a bare list; every *_meta mode yields a dict
        out = serialize_array(
            self.jser, self.array_2d, "test_path", array_mode=array_mode
        )
        assert isinstance(out, expected_type)

    def test_load_array(self):
        payload = serialize_array(
            self.jser, self.array_3d, "test_path", array_mode="array_list_meta"
        )
        restored = load_array(payload, array_mode="array_list_meta")
        assert np.array_equal(restored, self.array_3d)

    @pytest.mark.parametrize(
        "array_mode",
        ["list", "array_list_meta", "array_hex_meta", "array_b64_meta"],
    )
    def test_serialize_load_integration(self, array_mode: ArrayMode):
        for original in (self.array_1d, self.array_2d, self.array_3d):
            payload = serialize_array(
                self.jser, original, "test_path", array_mode=array_mode
            )
            # array_mode is only known at runtime here, so the type checker
            # cannot match a serialize_array/load_array overload pair.
            restored = load_array(payload, array_mode=array_mode)  # type: ignore[call-overload, arg-type]
            assert np.array_equal(restored, original)

    def test_serialize_load_list(self):
        """Round-trip in 'list' mode, with a literal mode for the type checker."""
        for original in (self.array_1d, self.array_2d, self.array_3d):
            payload = serialize_array(
                self.jser, original, "test_path", array_mode="list"
            )
            assert np.array_equal(load_array(payload, array_mode="list"), original)

    def test_serialize_load_array_list_meta(self):
        """Round-trip in 'array_list_meta' mode, with a literal mode for the type checker."""
        for original in (self.array_1d, self.array_2d, self.array_3d):
            payload = serialize_array(
                self.jser, original, "test_path", array_mode="array_list_meta"
            )
            restored = load_array(payload, array_mode="array_list_meta")
            assert np.array_equal(restored, original)

    def test_serialize_load_array_hex_meta(self):
        """Round-trip in 'array_hex_meta' mode, with a literal mode for the type checker."""
        for original in (self.array_1d, self.array_2d, self.array_3d):
            payload = serialize_array(
                self.jser, original, "test_path", array_mode="array_hex_meta"
            )
            restored = load_array(payload, array_mode="array_hex_meta")
            assert np.array_equal(restored, original)

    def test_serialize_load_array_b64_meta(self):
        """Round-trip in 'array_b64_meta' mode, with a literal mode for the type checker."""
        for original in (self.array_1d, self.array_2d, self.array_3d):
            payload = serialize_array(
                self.jser, original, "test_path", array_mode="array_b64_meta"
            )
            restored = load_array(payload, array_mode="array_b64_meta")
            assert np.array_equal(restored, original)

    # TODO: do we even want to support "list" mode for zero-dim arrays?
    @pytest.mark.parametrize(
        "array_mode",
        ["array_list_meta", "array_hex_meta", "array_b64_meta"],
    )
    def test_serialize_load_zero_dim(self, array_mode: ArrayModeWithMeta):
        payload = serialize_array(
            self.jser, self.array_zero_dim, "test_path", array_mode=array_mode
        )
        # No mode is passed to load_array here; presumably it is recovered
        # from the payload's metadata.
        restored = load_array(payload)
        assert np.array_equal(restored, self.array_zero_dim)

127 

128 

@pytest.mark.parametrize(
    "mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"]
)
def test_array_shape_dtype_preservation(mode: ArrayModeWithMeta):
    """Test that various shapes and dtypes are preserved through serialization.

    Every ``*_meta`` mode must round-trip shape, dtype and values exactly.
    Failure messages name the specific case (shape description or dtype) and
    the mode, so a parametrized failure can be pinpointed without a debugger.
    """
    # Test different shapes
    shapes_and_arrays: list[tuple[np.ndarray, str]] = [
        (np.array([1, 2, 3], dtype=np.int32), "1D int32"),
        (np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32), "2D float32"),
        (np.array([[[1]], [[2]]], dtype=np.int8), "3D int8"),
        (np.array([[[[1, 2, 3, 4]]]], dtype=np.int16), "4D int16"),
    ]

    # Test different dtypes
    dtype_tests: list[tuple[np.ndarray, type[np.generic]]] = [
        (np.array([1, 2, 3], dtype=np.int8), np.int8),
        (np.array([1, 2, 3], dtype=np.int16), np.int16),
        (np.array([1, 2, 3], dtype=np.int32), np.int32),
        (np.array([1, 2, 3], dtype=np.int64), np.int64),
        (np.array([1.0, 2.0, 3.0], dtype=np.float16), np.float16),
        (np.array([1.0, 2.0, 3.0], dtype=np.float32), np.float32),
        (np.array([1.0, 2.0, 3.0], dtype=np.float64), np.float64),
        (np.array([True, False, True], dtype=np.bool_), np.bool_),
    ]

    # `mode` is passed explicitly to every serialize_array call below, so this
    # serializer's own array_mode presumably does not matter here.
    jser = JsonSerializer(array_mode="array_list_meta")

    # Test shapes preservation
    for arr, description in shapes_and_arrays:
        serialized = serialize_array(jser, arr, "test", array_mode=mode)
        loaded = load_array(serialized)
        assert loaded.shape == arr.shape, f"Shape mismatch for {description} in {mode}"
        assert loaded.dtype == arr.dtype, f"Dtype mismatch for {description} in {mode}"
        assert np.array_equal(loaded, arr), f"Data mismatch for {description} in {mode}"

    # Test dtypes preservation
    for arr, expected_dtype in dtype_tests:
        dtype_name = expected_dtype.__name__
        serialized = serialize_array(jser, arr, "test", array_mode=mode)
        loaded = load_array(serialized)
        assert loaded.shape == arr.shape, (
            f"Shape not preserved for {dtype_name}: {mode}"
        )
        assert loaded.dtype == expected_dtype, (
            f"Dtype not preserved for {dtype_name}: {mode} (got {loaded.dtype})"
        )
        assert np.array_equal(loaded, arr), (
            f"Data not preserved for {dtype_name}: {mode}"
        )

170 

171 

def test_array_serialization_handlers():
    """Test integration with JsonSerializer - ensure arrays are serialized correctly when part of larger objects.

    Arrays are placed at the top level of a dict, nested one level deeper in a
    dict, and inside a list (both directly and within a dict element). Every
    one of them must come out as a metadata dict carrying ``_FORMAT_KEY`` and
    the right shape. Non-array siblings must pass through unchanged.
    """
    jser = JsonSerializer(array_mode="array_list_meta")

    # Array in a dict
    data_dict = {
        "metadata": {"name": "test"},
        "array": np.array([1, 2, 3, 4]),
        "nested": {"inner_array": np.array([[1, 2], [3, 4]])},
    }

    serialized = jser.json_serialize(data_dict)
    assert isinstance(serialized, dict)
    assert serialized["metadata"] == {"name": "test"}

    serialized_array = serialized["array"]
    assert isinstance(serialized_array, dict)
    assert _FORMAT_KEY in serialized_array
    assert serialized_array["shape"] == [4]

    # The nested array must be serialized too, not left as a raw ndarray.
    serialized_nested = serialized["nested"]
    assert isinstance(serialized_nested, dict)
    serialized_inner = serialized_nested["inner_array"]
    assert isinstance(serialized_inner, dict)
    assert _FORMAT_KEY in serialized_inner
    assert serialized_inner["shape"] == [2, 2]

    # Array in a list
    data_list = [
        {"value": 1},
        np.array([10, 20, 30]),
        {"value": 2, "data": np.array([[1, 2]])},
    ]

    serialized_list = jser.json_serialize(data_list)
    assert isinstance(serialized_list, list)
    assert len(serialized_list) == 3
    assert serialized_list[0] == {"value": 1}

    serialized_list_item = serialized_list[1]
    assert isinstance(serialized_list_item, dict)
    assert _FORMAT_KEY in serialized_list_item
    assert serialized_list_item["shape"] == [3]

    # Array inside a dict inside a list
    serialized_list_dict = serialized_list[2]
    assert isinstance(serialized_list_dict, dict)
    assert serialized_list_dict["value"] == 2
    serialized_list_data = serialized_list_dict["data"]
    assert isinstance(serialized_list_data, dict)
    assert _FORMAT_KEY in serialized_list_data
    assert serialized_list_data["shape"] == [1, 2]

    # Test different array modes. The explicit ArrayMode annotation lets the
    # type checker accept `mode` without a type: ignore.
    modes: list[ArrayMode] = [
        "list",
        "array_list_meta",
        "array_hex_meta",
        "array_b64_meta",
    ]
    arr = np.array([[1, 2, 3], [4, 5, 6]])
    for mode in modes:
        result = JsonSerializer(array_mode=mode).json_serialize(arr)

        if mode == "list":
            assert isinstance(result, list)
        else:
            assert isinstance(result, dict)
            assert _FORMAT_KEY in result

215 

216 

@pytest.mark.parametrize(
    "mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"]
)
def test_array_edge_cases(mode: ArrayModeWithMeta):
    """Test edge cases: empty arrays, unusual dtypes, and boundary conditions.

    Covers:
    - empty arrays of several ranks;
    - a complex64 array;
    - a larger random float64 array (seeded, so failures are reproducible);
    - IEEE special values (+/-inf, NaN, +/-0.0), including the sign of zero.
    """
    jser = JsonSerializer(array_mode="array_list_meta")

    # Empty arrays with different shapes
    empty_arrays: list[np.ndarray] = [
        np.array([], dtype=np.int32),
        np.array([[], []], dtype=np.float32).reshape(2, 0),
        np.array([[]], dtype=np.int64).reshape(1, 1, 0),
    ]

    for empty_arr in empty_arrays:
        serialized = serialize_array(jser, empty_arr, "test", array_mode=mode)
        loaded = load_array(serialized)
        assert loaded.shape == empty_arr.shape
        assert loaded.dtype == empty_arr.dtype
        assert np.array_equal(loaded, empty_arr)

    # Complex dtypes
    # NOTE(review): this round-trip is pinned to "array_list_meta" rather than
    # the parametrized `mode`. Confirm whether complex support in the hex/b64
    # modes is intended before switching it over.
    complex_arr = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64)
    serialized = serialize_array(
        jser, complex_arr, "test", array_mode="array_list_meta"
    )
    loaded = load_array(serialized)
    assert loaded.dtype == np.complex64
    assert np.array_equal(loaded, complex_arr)

    # Large arrays (test that serialization doesn't break). A local seeded
    # generator keeps the data reproducible and avoids touching global RNG state.
    rng = np.random.default_rng(0)
    large_arr = rng.random((100, 100))
    serialized = serialize_array(jser, large_arr, "test", array_mode=mode)
    loaded = load_array(serialized)
    assert loaded.shape == large_arr.shape
    assert loaded.dtype == large_arr.dtype
    assert np.allclose(loaded, large_arr)

    # Arrays with special values
    special_arr = np.array([np.inf, -np.inf, np.nan, 0.0, -0.0], dtype=np.float64)
    serialized = serialize_array(jser, special_arr, "test", array_mode=mode)
    loaded = load_array(serialized)
    assert loaded.dtype == special_arr.dtype
    # equal_nan=True treats NaN as equal to NaN, so every element (inf, -inf,
    # nan, 0.0, -0.0) is compared in one pass.
    assert np.array_equal(loaded, special_arr, equal_nan=True)
    # 0.0 == -0.0 compares True, so check sign bits explicitly to catch a
    # negative zero that was silently turned positive.
    assert np.array_equal(np.signbit(loaded), np.signbit(special_arr))