Coverage for tests / unit / json_serialize / test_array_torch.py: 85%
125 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 02:51 -0700
1from __future__ import annotations
3import numpy as np
4import pytest
5import torch
7from muutils.json_serialize import JsonSerializer
8from muutils.json_serialize.array import (
9 ArrayModeWithMeta,
10 arr_metadata,
11 array_n_elements,
12 load_array,
13 serialize_array,
14)
15from muutils.json_serialize.types import _FORMAT_KEY # pyright: ignore[reportPrivateUsage]
17# pylint: disable=missing-class-docstring
# The ``*_meta`` array modes; each round-trip test in this module is repeated
# once per entry.
_WITH_META_ARRAY_MODES: list[ArrayModeWithMeta] = [
    "array_list_meta", "array_hex_meta", "array_b64_meta",
]
def test_arr_metadata_torch():
    """Check that arr_metadata() reports shape, dtype and size for torch tensors."""
    # (tensor, expected shape, dtype substring or None, expected element count)
    # The dtype is matched as a substring because it may be reported either as
    # "torch.int64" or as "int64".
    cases: list[tuple[torch.Tensor, list[int], str | None, int]] = [
        (torch.tensor([1, 2, 3, 4, 5]), [5], "int64", 5),
        (
            torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32),
            [2, 2],
            "float32",
            4,
        ),
        (torch.randn(3, 4, 5, dtype=torch.float64), [3, 4, 5], "float64", 60),
        # zero-dimensional tensor: only shape and element count are checked
        (torch.tensor(42), [], None, 1),
    ]

    for tensor, shape, dtype_fragment, count in cases:
        meta = arr_metadata(tensor)
        assert meta["shape"] == shape
        if dtype_fragment is not None:
            assert dtype_fragment in meta["dtype"]
        assert meta["n_elements"] == count
def test_array_n_elements_torch():
    """array_n_elements() should equal the total element count of a torch tensor."""
    expectations = (
        (torch.tensor([1, 2, 3]), 3),
        (torch.tensor([[1, 2], [3, 4]]), 4),
        (torch.randn(2, 3, 4), 24),
        (torch.tensor(42), 1),  # a 0-d tensor still holds a single element
    )
    for tensor, expected in expectations:
        assert array_n_elements(tensor) == expected
def test_serialize_load_torch_tensors():
    """Test round-trip serialization of torch tensors in every ``*_meta`` array mode.

    Each tensor is serialized with ``serialize_array`` and read back with
    ``load_array``; the result must equal the tensor's numpy conversion in both
    value and shape.
    """
    jser = JsonSerializer(array_mode="array_list_meta")

    # Test various tensor types: 1D int32, 2D float32, 3D int64, 1D bool
    tensors = [
        torch.tensor([1, 2, 3, 4], dtype=torch.int32),
        torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float32),
        torch.tensor([[[1, 2]], [[3, 4]]], dtype=torch.int64),
        torch.tensor([True, False, True], dtype=torch.bool),
    ]

    for tensor in tensors:
        # The expected value depends only on the tensor, so convert once
        # rather than once per mode.
        tensor_np = tensor.cpu().numpy()
        expected_shape = tuple(tensor.shape)
        for mode in _WITH_META_ARRAY_MODES:
            serialized = serialize_array(jser, tensor, "test", array_mode=mode)  # type: ignore[arg-type]
            loaded = load_array(serialized)

            # Name the failing mode/dtype; the nested loops otherwise hide it.
            context = f"mode={mode!r}, dtype={tensor.dtype}"
            assert np.array_equal(loaded, tensor_np), context
            assert loaded.shape == expected_shape, context
def test_torch_shape_dtype_preservation():
    """Test that torch tensor dtypes (and shapes) survive a serialize/load round-trip.

    Each case pairs a tensor with the torch dtype it is expected to have. The
    loaded array's numpy dtype name must match that expected dtype, and its
    values and shape must match the original tensor.
    """
    jser = JsonSerializer(array_mode="array_list_meta")

    # Different dtypes
    dtype_tests = [
        (torch.tensor([1, 2, 3], dtype=torch.int8), torch.int8),
        (torch.tensor([1, 2, 3], dtype=torch.int16), torch.int16),
        (torch.tensor([1, 2, 3], dtype=torch.int32), torch.int32),
        (torch.tensor([1, 2, 3], dtype=torch.int64), torch.int64),
        (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16), torch.float16),
        (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), torch.float32),
        (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64), torch.float64),
        (torch.tensor([True, False, True], dtype=torch.bool), torch.bool),
    ]

    for tensor, expected_dtype in dtype_tests:
        # Guard against a typo in the table: the tensor must really have the
        # dtype it is paired with.
        assert tensor.dtype == expected_dtype
        # torch dtypes print as "torch.<name>"; for every dtype listed above,
        # <name> is also the numpy dtype name.
        expected_np_name = str(expected_dtype).split(".")[-1]
        tensor_np = tensor.cpu().numpy()

        for mode in _WITH_META_ARRAY_MODES:
            serialized = serialize_array(jser, tensor, "test", array_mode=mode)  # type: ignore[arg-type]
            loaded = load_array(serialized)

            context = f"mode={mode!r}, dtype={expected_dtype}"
            assert np.array_equal(loaded, tensor_np), context
            assert loaded.dtype.name == expected_np_name, context
            assert loaded.shape == tuple(tensor.shape), context
def test_torch_zero_dim_tensor():
    """A 0-d torch tensor must load back as a 0-d array holding the same value."""
    jser = JsonSerializer(array_mode="array_list_meta")
    scalar = torch.tensor(42)

    for mode in _WITH_META_ARRAY_MODES:
        result = load_array(
            serialize_array(jser, scalar, "test", array_mode=mode)  # type: ignore[arg-type]
        )
        # Zero-dim tensors have special handling, so check the shape as well
        # as the value.
        assert result.shape == scalar.shape
        assert np.array_equal(result, scalar.cpu().numpy())
def test_torch_edge_cases():
    """Test edge cases with torch tensors.

    Covers an empty tensor, non-finite values and signed zeros (in every
    ``*_meta`` mode), and a large tensor in b64 mode.
    """
    jser = JsonSerializer(array_mode="array_list_meta")

    # Empty tensors
    empty_1d = torch.tensor([], dtype=torch.float32)
    serialized = serialize_array(jser, empty_1d, "test", array_mode="array_list_meta")
    loaded = load_array(serialized)
    assert loaded.shape == (0,)

    # Tensors with special values
    special_tensor = torch.tensor(
        [float("inf"), float("-inf"), float("nan"), 0.0, -0.0]
    )
    for mode in _WITH_META_ARRAY_MODES:
        serialized = serialize_array(jser, special_tensor, "test", array_mode=mode)  # type: ignore[arg-type]
        loaded = load_array(serialized)

        assert loaded.shape == (5,), mode
        # Check special values
        assert np.isinf(loaded[0]) and loaded[0] > 0, mode  # pyright: ignore[reportAny]
        assert np.isinf(loaded[1]) and loaded[1] < 0, mode  # pyright: ignore[reportAny]
        assert np.isnan(loaded[2]), mode  # pyright: ignore[reportAny]
        # Both zeros must come back as zero (-0.0 == 0.0, so the sign of the
        # second one is not asserted here).
        assert loaded[3] == 0.0 and loaded[4] == 0.0, mode  # pyright: ignore[reportAny]

    # Large tensor. The other b64 round-trips in this module expect exact
    # equality for float32 data, so use array_equal rather than allclose.
    large_tensor = torch.randn(100, 100)
    serialized = serialize_array(
        jser, large_tensor, "test", array_mode="array_b64_meta"
    )
    loaded = load_array(serialized)
    assert loaded.shape == (100, 100)
    assert np.array_equal(loaded, large_tensor.cpu().numpy())
def test_torch_gpu_tensors():
    """Test that a CUDA tensor round-trips after being copied to the CPU.

    The GPU tensor is moved to the CPU before ``serialize_array`` is called,
    so this checks the GPU -> CPU -> serialize -> load path. It is skipped
    when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        # TYPING: ty bug on python <= 3.9
        pytest.skip("CUDA not available")  # ty: ignore[arg-type,invalid-argument-type,too-many-positional-arguments]

    jser = JsonSerializer(array_mode="array_list_meta")

    # Create GPU tensor
    tensor_gpu = torch.tensor([1, 2, 3, 4], dtype=torch.float32, device="cuda")

    # Need to move to CPU first for numpy conversion. Neither the CPU copy nor
    # the expected numpy array depends on the mode, so build them once instead
    # of twice per iteration.
    tensor_cpu_torch = tensor_gpu.cpu()
    tensor_cpu = tensor_cpu_torch.numpy()

    for mode in _WITH_META_ARRAY_MODES:
        serialized = serialize_array(jser, tensor_cpu_torch, "test", array_mode=mode)  # type: ignore[arg-type]
        loaded = load_array(serialized)

        # Should match the CPU version
        assert np.array_equal(loaded, tensor_cpu), mode
        assert loaded.shape == tuple(tensor_gpu.shape), mode
def test_torch_serialization_integration():
    """Test torch tensors integrated with JsonSerializer in complex structures.

    Tensors at the top level and nested inside a list of dicts must all be
    serialized as format-tagged dicts. Plain values in the structure must pass
    through unchanged.
    """
    jser = JsonSerializer(array_mode="array_list_meta")

    # Mixed structure with torch tensors
    data = {
        "model_weights": torch.randn(10, 5),
        "biases": torch.randn(5),
        "metadata": {"epochs": 10, "lr": 0.001},
        "history": [
            {"loss": torch.tensor(0.5), "accuracy": torch.tensor(0.95)},
            {"loss": torch.tensor(0.3), "accuracy": torch.tensor(0.97)},
        ],
    }

    serialized = jser.json_serialize(data)
    assert isinstance(serialized, dict)

    # Check structure is preserved
    serialized_weights = serialized["model_weights"]
    assert isinstance(serialized_weights, dict)
    assert _FORMAT_KEY in serialized_weights
    assert serialized_weights["shape"] == [10, 5]  # pyright: ignore[reportGeneralTypeIssues]

    serialized_biases = serialized["biases"]
    assert isinstance(serialized_biases, dict)
    assert _FORMAT_KEY in serialized_biases
    assert serialized_biases["shape"] == [5]  # pyright: ignore[reportGeneralTypeIssues]

    serialized_metadata = serialized["metadata"]
    assert isinstance(serialized_metadata, dict)
    assert serialized_metadata["epochs"] == 10  # pyright: ignore[reportGeneralTypeIssues]
    assert serialized_metadata["lr"] == 0.001  # pyright: ignore[reportGeneralTypeIssues]

    # Check every nested tensor in the history, not just the first loss
    serialized_history = serialized["history"]
    assert isinstance(serialized_history, list)
    assert len(serialized_history) == 2
    for index, history_item in enumerate(serialized_history):
        assert isinstance(history_item, dict), index
        for key in ("loss", "accuracy"):
            value = history_item[key]  # pyright: ignore[reportGeneralTypeIssues]
            assert isinstance(value, dict), (index, key)
            assert _FORMAT_KEY in value, (index, key)
def test_mixed_numpy_torch():
    """Test that both numpy arrays and torch tensors can coexist in serialization.

    Every array, top-level or nested, must become a format-tagged dict whose
    format string names its source library ("numpy" or "torch").
    """
    jser = JsonSerializer(array_mode="array_list_meta")

    data = {
        "numpy_array": np.array([1, 2, 3]),
        "torch_tensor": torch.tensor([4, 5, 6]),
        "nested": {
            "np": np.array([[1, 2]]),
            "torch": torch.tensor([[3, 4]]),
        },
    }

    serialized = jser.json_serialize(data)
    assert isinstance(serialized, dict)

    serialized_nested = serialized["nested"]
    assert isinstance(serialized_nested, dict)

    # (serialized value, library expected in its format string, expected shape)
    # The nested entries were previously never checked.
    cases = [
        ("numpy_array", serialized["numpy_array"], "numpy", [3]),
        ("torch_tensor", serialized["torch_tensor"], "torch", [3]),
        ("nested.np", serialized_nested["np"], "numpy", [1, 2]),  # pyright: ignore[reportGeneralTypeIssues]
        ("nested.torch", serialized_nested["torch"], "torch", [1, 2]),  # pyright: ignore[reportGeneralTypeIssues]
    ]
    for label, value, library, shape in cases:
        # Both should be serialized as dicts with metadata
        assert isinstance(value, dict), label
        assert _FORMAT_KEY in value, label
        # Check format strings identify the type
        format_str = value[_FORMAT_KEY]
        assert isinstance(format_str, str), label
        assert library in format_str, label
        assert value["shape"] == shape, label  # pyright: ignore[reportGeneralTypeIssues]