Coverage for tests / unit / json_serialize / test_array_torch.py: 85%

125 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1from __future__ import annotations 

2 

3import numpy as np 

4import pytest 

5import torch 

6 

7from muutils.json_serialize import JsonSerializer 

8from muutils.json_serialize.array import ( 

9 ArrayModeWithMeta, 

10 arr_metadata, 

11 array_n_elements, 

12 load_array, 

13 serialize_array, 

14) 

15from muutils.json_serialize.types import _FORMAT_KEY # pyright: ignore[reportPrivateUsage] 

16 

17# pylint: disable=missing-class-docstring 

18 

19 

# Array modes (typed as ArrayModeWithMeta) that the round-trip tests in this
# module iterate over when calling serialize_array().
_WITH_META_ARRAY_MODES: list[ArrayModeWithMeta] = [
    "array_list_meta",
    "array_hex_meta",
    "array_b64_meta",
]

25 

26 

def test_arr_metadata_torch():
    """Check the shape, dtype and element count arr_metadata() reports for torch tensors."""
    # Each case: (tensor, expected shape, substring expected in the dtype string
    # or None when the dtype is not checked, expected n_elements)
    cases = [
        # 1D tensor; the dtype string could be "torch.int64" or "int64"
        (torch.tensor([1, 2, 3, 4, 5]), [5], "int64", 5),
        # 2D tensor
        (
            torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32),
            [2, 2],
            "float32",
            4,
        ),
        # 3D tensor
        (torch.randn(3, 4, 5, dtype=torch.float64), [3, 4, 5], "float64", 60),
        # Zero-dimensional tensor: only shape and element count are checked
        (torch.tensor(42), [], None, 1),
    ]

    for tensor, shape, dtype_fragment, n_elements in cases:
        meta = arr_metadata(tensor)
        assert meta["shape"] == shape
        if dtype_fragment is not None:
            assert dtype_fragment in meta["dtype"]
        assert meta["n_elements"] == n_elements

55 

56 

def test_array_n_elements_torch():
    """Check array_n_elements() against known element counts for torch tensors."""
    # (tensor, expected number of elements), from 1D up to 3D plus a 0-dim tensor
    expected_counts = [
        (torch.tensor([1, 2, 3]), 3),
        (torch.tensor([[1, 2], [3, 4]]), 4),
        (torch.randn(2, 3, 4), 24),
        (torch.tensor(42), 1),
    ]

    for tensor, count in expected_counts:
        assert array_n_elements(tensor) == count

63 

64 

def test_serialize_load_torch_tensors():
    """Round-trip torch tensors through serialize_array() and load_array()."""
    serializer = JsonSerializer(array_mode="array_list_meta")

    # Integer, float, nested-integer and boolean samples
    samples = [
        torch.tensor([1, 2, 3, 4], dtype=torch.int32),
        torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float32),
        torch.tensor([[[1, 2]], [[3, 4]]], dtype=torch.int64),
        torch.tensor([True, False, True], dtype=torch.bool),
    ]

    # Every (tensor, mode) combination, with the tensor varying slowest
    round_trips = [
        (sample, mode) for sample in samples for mode in _WITH_META_ARRAY_MODES
    ]

    for sample, mode in round_trips:
        payload = serialize_array(serializer, sample, "test", array_mode=mode)  # type: ignore[arg-type]
        restored = load_array(payload)

        # The numpy conversion of the tensor is the reference for values and shape
        expected = sample.cpu().numpy()
        assert np.array_equal(restored, expected)
        assert restored.shape == tuple(sample.shape)

86 

87 

def test_torch_shape_dtype_preservation():
    """Test that torch tensor dtypes are preserved by a serialize/load round trip.

    Each tensor is serialized in every mode from ``_WITH_META_ARRAY_MODES`` and
    loaded back. The loaded array must equal the tensor's numpy conversion and
    report the same dtype name.
    """
    jser = JsonSerializer(array_mode="array_list_meta")

    # One tensor per dtype under test. The dtype is read back from the tensor
    # itself, so no separate "expected dtype" column is needed.
    tensors = [
        torch.tensor([1, 2, 3], dtype=torch.int8),
        torch.tensor([1, 2, 3], dtype=torch.int16),
        torch.tensor([1, 2, 3], dtype=torch.int32),
        torch.tensor([1, 2, 3], dtype=torch.int64),
        torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16),
        torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32),
        torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64),
        torch.tensor([True, False, True], dtype=torch.bool),
    ]

    for tensor in tensors:
        # The reference does not depend on the array mode, so convert once per tensor
        tensor_np = tensor.cpu().numpy()

        for mode in _WITH_META_ARRAY_MODES:
            serialized = serialize_array(jser, tensor, "test", array_mode=mode)  # type: ignore[arg-type]
            loaded = load_array(serialized)

            assert np.array_equal(loaded, tensor_np)
            assert loaded.dtype.name == tensor_np.dtype.name

113 

114 

def test_torch_zero_dim_tensor():
    """Round-trip a zero-dimensional torch tensor in each array mode."""
    serializer = JsonSerializer(array_mode="array_list_meta")
    scalar = torch.tensor(42)

    for mode in _WITH_META_ARRAY_MODES:
        payload = serialize_array(serializer, scalar, "test", array_mode=mode)  # type: ignore[arg-type]
        restored = load_array(payload)

        # The loaded array must keep both the tensor's shape and its value
        assert restored.shape == scalar.shape
        assert np.array_equal(restored, scalar.cpu().numpy())

128 

129 

def test_torch_edge_cases():
    """Exercise empty, non-finite and larger torch tensors."""
    serializer = JsonSerializer(array_mode="array_list_meta")

    # Empty tensor: only the loaded shape is checked
    no_elements = torch.tensor([], dtype=torch.float32)
    payload = serialize_array(
        serializer, no_elements, "test", array_mode="array_list_meta"
    )
    restored = load_array(payload)
    assert restored.shape == (0,)

    # Infinities, NaN and signed zeros
    non_finite = torch.tensor([float("inf"), float("-inf"), float("nan"), 0.0, -0.0])
    for mode in _WITH_META_ARRAY_MODES:
        payload = serialize_array(serializer, non_finite, "test", array_mode=mode)  # type: ignore[arg-type]
        restored = load_array(payload)

        # Only the first three entries (+inf, -inf, nan) are asserted on
        assert np.isinf(restored[0])  # pyright: ignore[reportAny]
        assert restored[0] > 0  # pyright: ignore[reportAny]
        assert np.isinf(restored[1])  # pyright: ignore[reportAny]
        assert restored[1] < 0  # pyright: ignore[reportAny]
        assert np.isnan(restored[2])  # pyright: ignore[reportAny]

    # 100x100 random tensor, serialized in the base64 mode only
    big = torch.randn(100, 100)
    payload = serialize_array(serializer, big, "test", array_mode="array_b64_meta")
    restored = load_array(payload)
    assert np.allclose(restored, big.cpu().numpy())

160 

161 

def test_torch_gpu_tensors():
    """Test serialization of a tensor created on a CUDA device (skipped without CUDA).

    The tensor is created on the GPU and copied to the CPU before being passed
    to serialize_array(). The loaded array is compared against the numpy
    conversion of that CPU copy.
    """
    if not torch.cuda.is_available():
        # TYPING: ty bug on python <= 3.9
        pytest.skip("CUDA not available")  # ty: ignore[arg-type,invalid-argument-type,too-many-positional-arguments]

    jser = JsonSerializer(array_mode="array_list_meta")

    # Create GPU tensor
    tensor_gpu = torch.tensor([1, 2, 3, 4], dtype=torch.float32, device="cuda")

    # Copy to the CPU once, outside the loop: the copy does not depend on the
    # array mode, and it serves as both the serialization input and (through
    # numpy) the comparison reference.
    tensor_cpu_torch = tensor_gpu.cpu()
    tensor_cpu = tensor_cpu_torch.numpy()

    for mode in _WITH_META_ARRAY_MODES:
        serialized = serialize_array(jser, tensor_cpu_torch, "test", array_mode=mode)  # type: ignore[arg-type]
        loaded = load_array(serialized)

        # Should match the CPU version
        assert np.array_equal(loaded, tensor_cpu)

182 

183 

def test_torch_serialization_integration():
    """Serialize a nested dict/list structure holding torch tensors with JsonSerializer."""
    serializer = JsonSerializer(array_mode="array_list_meta")

    # Tensors at the top level, plain values in a sub-dict, 0-dim tensors in a list of dicts
    payload = {
        "model_weights": torch.randn(10, 5),
        "biases": torch.randn(5),
        "metadata": {"epochs": 10, "lr": 0.001},
        "history": [
            {"loss": torch.tensor(0.5), "accuracy": torch.tensor(0.95)},
            {"loss": torch.tensor(0.3), "accuracy": torch.tensor(0.97)},
        ],
    }

    out = serializer.json_serialize(payload)
    assert isinstance(out, dict)

    # Top-level tensors come back as dicts carrying the format key and a shape
    weights_out = out["model_weights"]
    assert isinstance(weights_out, dict)
    assert _FORMAT_KEY in weights_out
    assert weights_out["shape"] == [10, 5]  # pyright: ignore[reportGeneralTypeIssues]

    biases_out = out["biases"]
    assert isinstance(biases_out, dict)
    assert biases_out["shape"] == [5]  # pyright: ignore[reportGeneralTypeIssues]

    # Plain values pass through unchanged
    metadata_out = out["metadata"]
    assert isinstance(metadata_out, dict)
    assert metadata_out["epochs"] == 10  # pyright: ignore[reportGeneralTypeIssues]

    # Tensors nested inside the list of dicts are serialized too (first entry checked)
    history_out = out["history"]
    assert isinstance(history_out, list)
    first_entry = history_out[0]
    assert isinstance(first_entry, dict)
    first_loss = first_entry["loss"]  # pyright: ignore[reportGeneralTypeIssues]
    assert isinstance(first_loss, dict)
    assert _FORMAT_KEY in first_loss

223 

224 

def test_mixed_numpy_torch():
    """Serialize numpy arrays and torch tensors together and check each entry's format tag."""
    serializer = JsonSerializer(array_mode="array_list_meta")

    payload = {
        "numpy_array": np.array([1, 2, 3]),
        "torch_tensor": torch.tensor([4, 5, 6]),
        "nested": {
            "np": np.array([[1, 2]]),
            "torch": torch.tensor([[3, 4]]),
        },
    }

    out = serializer.json_serialize(payload)
    assert isinstance(out, dict)

    np_entry = out["numpy_array"]
    torch_entry = out["torch_tensor"]

    # Both top-level entries must be dicts that carry the format key
    assert isinstance(np_entry, dict)
    assert isinstance(torch_entry, dict)
    assert _FORMAT_KEY in np_entry
    assert _FORMAT_KEY in torch_entry

    # The format string names the library the array came from
    np_format = np_entry[_FORMAT_KEY]
    assert isinstance(np_format, str)
    assert "numpy" in np_format

    torch_format = torch_entry[_FORMAT_KEY]
    assert isinstance(torch_format, str)
    assert "torch" in torch_format