Coverage for tests/unit/json_serialize/serializable_dataclass/test_helpers.py: 100%
102 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1from __future__ import annotations
3from dataclasses import dataclass
5import numpy as np
6import torch
8from muutils.json_serialize.serializable_dataclass import array_safe_eq, dc_eq
def test_array_safe_eq():
    """Check ``array_safe_eq`` on arrays, tensors, empty and nested containers.

    Each case is ``(lhs, rhs, expected)``, where ``expected`` states whether
    the two values should be reported as equal.
    """
    cases = [
        # flat numpy arrays
        (np.array([1, 2, 3]), np.array([1, 2, 3]), True),
        (np.array([1, 2, 3]), np.array([4, 5, 6]), False),
        # flat torch tensors
        (torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]), True),
        (torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), False),
        # empty arrays, rank 1 and rank 2
        (np.array([]), np.array([]), True),
        (np.array([[]]), np.array([[]]), True),
        # plain Python containers
        ([], [], True),
        ({}, {}, True),
        ([1, 2, 3], [1, 2, 3], True),
        # lists holding arrays, including length and nesting mismatches
        ([np.array([1, 2, 3])], [np.array([1, 2, 3])], True),
        ([], [np.array([1, 2, 3])], False),
        ([[np.array([1, 2, 3])]], [np.array([1, 2, 3])], False),
        # mixed numpy/torch elements, and sensitivity to element order
        (
            [np.array([1, 2, 3]), torch.tensor([1, 2, 3])],
            [np.array([1, 2, 3]), torch.tensor([1, 2, 3])],
            True,
        ),
        ([np.array([1, 2, 3]), []], [np.array([1, 2, 3]), []], True),
        ([[], np.array([1, 2, 3])], [np.array([1, 2, 3]), []], False),
    ]
    for lhs, rhs, expected in cases:
        # the tuple message shows the offending pair if a case fails
        assert bool(array_safe_eq(lhs, rhs)) is expected, (lhs, rhs)
def test_dc_eq_case1():
    """Two instances of one dataclass with identical field values are ``dc_eq``-equal."""

    # eq=False: the dataclass generates no __eq__, so equality is decided by dc_eq alone
    @dataclass(eq=False)
    class TestClass:
        a: int
        b: np.ndarray
        c: torch.Tensor
        e: list[int]
        f: dict[str, int]

    def make_instance():
        # fresh array/tensor/container objects on every call, so the two
        # instances never share field objects
        return TestClass(
            a=1,
            b=np.array([1, 2, 3]),
            c=torch.tensor([1, 2, 3]),
            e=[1, 2, 3],
            f={"key1": 1, "key2": 2},
        )

    left = make_instance()
    right = make_instance()
    assert dc_eq(left, right)
def test_dc_eq_case2():
    """Instances differing only in the numpy array field ``b`` are not ``dc_eq``-equal."""

    # eq=False: the dataclass generates no __eq__, so equality is decided by dc_eq alone
    @dataclass(eq=False)
    class TestClass:
        a: int
        b: np.ndarray
        c: torch.Tensor
        e: list[int]
        f: dict[str, int]

    def make_instance(b):
        # every field except `b` is rebuilt with the same values on each call
        return TestClass(
            a=1,
            b=b,
            c=torch.tensor([1, 2, 3]),
            e=[1, 2, 3],
            f={"key1": 1, "key2": 2},
        )

    left = make_instance(np.array([1, 2, 3]))
    right = make_instance(np.array([4, 5, 6]))
    assert not dc_eq(left, right)
def test_dc_eq_case3():
    """Instances differing only in the scalar field ``a`` are not ``dc_eq``-equal."""

    # eq=False: the dataclass generates no __eq__, so equality is decided by dc_eq alone
    @dataclass(eq=False)
    class TestClass:
        a: int
        b: np.ndarray
        c: torch.Tensor
        e: list[int]
        f: dict[str, int]

    def make_instance(a):
        # every field except `a` is rebuilt with the same values on each call
        return TestClass(
            a=a,
            b=np.array([1, 2, 3]),
            c=torch.tensor([1, 2, 3]),
            e=[1, 2, 3],
            f={"key1": 1, "key2": 2},
        )

    left = make_instance(1)
    right = make_instance(2)
    assert not dc_eq(left, right)
def test_dc_eq_case4():
    """Field-for-field identical instances of two distinct dataclasses are not ``dc_eq``-equal."""

    # Two structurally identical but distinct classes; eq=False means neither
    # generates an __eq__, so equality is decided by dc_eq alone.
    @dataclass(eq=False)
    class TestClass:
        a: int
        b: np.ndarray
        c: torch.Tensor
        e: list[int]
        f: dict[str, int]

    @dataclass(eq=False)
    class TestClass2:
        a: int
        b: np.ndarray
        c: torch.Tensor
        e: list[int]
        f: dict[str, int]

    def make_instance(cls):
        # same field values for either class, with fresh objects per call
        return cls(
            a=1,
            b=np.array([1, 2, 3]),
            c=torch.tensor([1, 2, 3]),
            e=[1, 2, 3],
            f={"key1": 1, "key2": 2},
        )

    left = make_instance(TestClass)
    right = make_instance(TestClass2)
    assert not dc_eq(left, right)
def test_dc_eq_case5():
    """Single-field instances of two distinct dataclasses are not ``dc_eq``-equal."""

    @dataclass(eq=False)
    class TestClass:
        a: int

    @dataclass(eq=False)
    class TestClass2:
        a: int

    # same value for the only field; only the class differs
    left, right = TestClass(a=1), TestClass2(a=1)
    assert not dc_eq(left, right)
def test_dc_eq_case6():
    """Field-less ``eq=False`` dataclasses of different types are not ``dc_eq``-equal."""

    @dataclass(eq=False)
    class TestClass:
        pass

    @dataclass(eq=False)
    class TestClass2:
        pass

    # with no fields to compare, only the differing class remains
    left, right = TestClass(), TestClass2()
    assert not dc_eq(left, right)
def test_dc_eq_case7():
    """Field-less default (``eq=True``) dataclasses of different types are not ``dc_eq``-equal."""

    # plain @dataclass, so each class does generate its own __eq__ here
    @dataclass
    class TestClass:
        pass

    @dataclass
    class TestClass2:
        pass

    left, right = TestClass(), TestClass2()
    assert not dc_eq(left, right)