Coverage for tests / unit / test_collect_warnings.py: 99%

256 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

from __future__ import annotations

import os
import sys
import warnings
from contextlib import redirect_stderr
from io import StringIO

import pytest

from muutils.collect_warnings import CollateWarnings


def test_basic_warning_capture():
    """Warnings raised inside the block are recorded in ``cw.counts``."""
    with CollateWarnings(print_on_exit=False) as cw:
        warnings.warn("test warning 1", UserWarning)
        warnings.warn("test warning 2", DeprecationWarning)

    assert len(cw.counts) == 2

    # every key is (filename, lineno, category name, message)
    recorded_messages = [message for (_f, _l, _c, message) in cw.counts.keys()]
    assert "test warning 1" in recorded_messages
    assert "test warning 2" in recorded_messages

    recorded_categories = [category for (_f, _l, category, _m) in cw.counts.keys()]
    assert "UserWarning" in recorded_categories
    assert "DeprecationWarning" in recorded_categories

    # neither warning repeats, so each tally is exactly one
    assert all(tally == 1 for tally in cw.counts.values())


def test_collation():
    """Repeats of one warning from one line collapse into a single tallied entry."""
    with CollateWarnings(print_on_exit=False) as cw:
        # the loop re-issues the warning from the very same source line
        for _ in range(3):
            warnings.warn("duplicate warning", UserWarning)
        warnings.warn("different warning", UserWarning)

    # map message -> tally; a missing message yields None, like the original
    by_message = {message: tally for (_, _, _, message), tally in cw.counts.items()}

    assert by_message.get("duplicate warning") == 3
    assert by_message.get("different warning") == 1


def test_print_on_exit_true():
    """Test that warnings are printed to stderr on exit when print_on_exit=True."""
    buf = StringIO()
    # redirect_stderr swaps sys.stderr and restores it even if the block raises,
    # replacing the hand-written save/try/finally dance.
    with redirect_stderr(buf), CollateWarnings(print_on_exit=True):
        warnings.warn("printed warning", UserWarning)

    stderr_output = buf.getvalue()

    assert "printed warning" in stderr_output
    assert "UserWarning" in stderr_output
    assert "(1x)" in stderr_output  # default format includes the count


def test_print_on_exit_false():
    """Test that no output is produced but counts are tracked when print_on_exit=False."""
    buf = StringIO()
    # stderr is redirected for the whole lifetime of the collator, including
    # its __exit__, which is where any printing would happen.
    with redirect_stderr(buf), CollateWarnings(print_on_exit=False) as cw:
        warnings.warn("silent warning", UserWarning)

    # nothing was printed ...
    assert buf.getvalue() == ""

    # ... but the warning was still tallied
    assert len(cw.counts) == 1
    warning_messages = [msg for (_, _, _, msg) in cw.counts.keys()]
    assert "silent warning" in warning_messages


def test_custom_format_string():
    """Test that the fmt parameter fully controls the printed line."""
    custom_fmt = "WARNING: {message} ({category}) appeared {count} times"
    buf = StringIO()
    with redirect_stderr(buf), CollateWarnings(print_on_exit=True, fmt=custom_fmt):
        warnings.warn("custom format warning", UserWarning)

    stderr_output = buf.getvalue()

    # the whole rendered line, not just scattered fragments of it
    assert (
        "WARNING: custom format warning (UserWarning) appeared 1 times"
        in stderr_output
    )
    # the default "({count}x) ..." format must not have been used
    assert "(1x)" not in stderr_output


def test_multiple_different_warnings():
    """Four warnings of four categories yield four separate entries."""
    with CollateWarnings(print_on_exit=False) as cw:
        warnings.warn("warning 1", UserWarning)
        warnings.warn("warning 2", DeprecationWarning)
        warnings.warn("warning 3", FutureWarning)
        warnings.warn("warning 4", RuntimeWarning)

    assert len(cw.counts) == 4

    seen = [category for (_, _, category, _) in cw.counts.keys()]
    expected_names = (
        "UserWarning",
        "DeprecationWarning",
        "FutureWarning",
        "RuntimeWarning",
    )
    for name in expected_names:
        assert name in seen


def test_no_warnings():
    """A block that issues nothing leaves the counter empty."""
    with CollateWarnings(print_on_exit=False) as cw:
        pass  # deliberately no warnings

    # a Counter is falsy exactly when it has no entries
    assert not cw.counts


def test_same_message_different_categories():
    """Identical message text under two categories is tallied per category."""
    with CollateWarnings(print_on_exit=False) as cw:
        # UserWarning twice from a single line (via the loop), then one DeprecationWarning
        for _ in range(2):
            warnings.warn("same message", UserWarning)
        warnings.warn("same message", DeprecationWarning)

    totals = {"UserWarning": 0, "DeprecationWarning": 0}
    for (_, _, category, message), tally in cw.counts.items():
        if message == "same message" and category in totals:
            totals[category] += tally

    assert totals["UserWarning"] == 2
    assert totals["DeprecationWarning"] == 1


def test_filename_and_lineno_tracking():
    """Test that the recorded filename and line number point at the warn() call."""
    with CollateWarnings(print_on_exit=False) as cw:
        # The warning is issued on the line immediately after this one.
        expected_lineno = sys._getframe().f_lineno + 1
        warnings.warn("tracked warning", UserWarning)

    assert len(cw.counts) == 1

    ((filename, lineno, category, message),) = cw.counts.keys()

    assert isinstance(filename, str)
    # Compare basenames: the absolute path form can differ (e.g. cached
    # bytecode compiled under another checkout location).
    assert os.path.basename(filename) == os.path.basename(__file__)

    assert isinstance(lineno, int)
    assert lineno == expected_lineno

    assert category == "UserWarning"
    assert message == "tracked warning"


def test_context_manager_re_entry_fails():
    """Entering an already-active CollateWarnings instance raises RuntimeError."""
    collator = CollateWarnings(print_on_exit=False)

    with collator:
        # a nested `with` on the same, still-active instance must be rejected
        with pytest.raises(RuntimeError, match="cannot be re-entered"):
            with collator:
                pass


def test_format_string_all_fields():
    """Test that every format field is substituted with the recorded values."""
    fmt = "count={count} file={filename} line={lineno} cat={category} msg={message}"
    buf = StringIO()
    with redirect_stderr(buf), CollateWarnings(print_on_exit=True, fmt=fmt) as cw:
        warnings.warn("test all fields", UserWarning)

    stderr_output = buf.getvalue()

    # Pull the actual recorded location so file=/line= are checked against
    # real values rather than merely being present.
    ((filename, lineno, _, _),) = cw.counts.keys()
    expected = (
        f"count=1 file={filename} line={lineno} "
        "cat=UserWarning msg=test all fields"
    )
    assert expected in stderr_output


def test_warning_with_stacklevel():
    """Test that stacklevel=2 attributes the warning to the caller's line."""

    def issue_warning():
        warnings.warn("nested warning", UserWarning, stacklevel=2)

    with CollateWarnings(print_on_exit=False) as cw:
        # issue_warning() is called on the next line; stacklevel=2 should
        # attribute the warning there, not to the warn() line in the helper.
        call_lineno = sys._getframe().f_lineno + 1
        issue_warning()

    assert len(cw.counts) == 1
    ((_, lineno, _, message),) = cw.counts.keys()
    assert message == "nested warning"
    assert lineno == call_lineno


def test_counts_dict_structure():
    """``counts`` is a Counter keyed by (filename, lineno, category, message)."""
    from collections import Counter

    with CollateWarnings(print_on_exit=False) as cw:
        warnings.warn("test warning", UserWarning)

    assert isinstance(cw.counts, Counter)

    first_key = list(cw.counts)[0]
    assert isinstance(first_key, tuple)
    assert len(first_key) == 4

    # filename, lineno, category name, message
    expected_types = (str, int, str, str)
    for field, expected_type in zip(first_key, expected_types):
        assert isinstance(field, expected_type)


def test_large_number_of_warnings():
    """Test that 1000 repeats of one warning collapse into a single entry."""
    with CollateWarnings(print_on_exit=False) as cw:
        # loop index is unused, so use `_` (was an unused `i`)
        for _ in range(1000):
            warnings.warn("repeated warning", UserWarning)

    assert len(cw.counts) == 1

    # exactly one entry, so unpack its tally directly
    (count,) = cw.counts.values()
    assert count == 1000


def test_mixed_warning_counts():
    """Three distinct warnings with tallies 5, 3 and 1 are counted independently."""
    with CollateWarnings(print_on_exit=False) as cw:
        # "warning A": five repeats
        for _ in range(5):
            warnings.warn("warning A", UserWarning)

        # "warning B": three repeats
        for _ in range(3):
            warnings.warn("warning B", DeprecationWarning)

        # "warning C": issued once
        warnings.warn("warning C", FutureWarning)

    assert len(cw.counts) == 3

    tallies = {message: tally for (_, _, _, message), tally in cw.counts.items()}

    assert tallies["warning A"] == 5
    assert tallies["warning B"] == 3
    assert tallies["warning C"] == 1


def test_exception_propagation():
    """Test that exceptions from the with-block propagate and counts are still filled."""
    # Build the collator before pytest.raises so `cw` is unconditionally bound
    # when inspected afterwards (no pyright possibly-unbound suppression needed).
    cw = CollateWarnings(print_on_exit=False)

    with pytest.raises(ValueError, match="test exception"):
        with cw:
            warnings.warn("warning before exception", UserWarning)
            raise ValueError("test exception")

    # __exit__ still ran during unwinding, so the earlier warning was collated
    assert len(cw.counts) == 1
    assert [msg for (_, _, _, msg) in cw.counts] == ["warning before exception"]


def test_warning_with_special_characters():
    """Quotes, newlines and tabs survive unchanged in recorded messages."""
    quoted = "warning with 'quotes' and \"double quotes\""
    with_newline = "warning with\nnewline"
    with_tab = "warning with\ttab"

    with CollateWarnings(print_on_exit=False) as cw:
        warnings.warn(quoted, UserWarning)
        warnings.warn(with_newline, UserWarning)
        warnings.warn(with_tab, UserWarning)

    assert len(cw.counts) == 3

    recorded = [message for (_, _, _, message) in cw.counts.keys()]
    for expected in (quoted, with_newline, with_tab):
        assert expected in recorded


def test_empty_warning_message():
    """An empty message string is recorded as-is."""
    with CollateWarnings(print_on_exit=False) as cw:
        warnings.warn("", UserWarning)

    assert len(cw.counts) == 1
    assert any(message == "" for (_, _, _, message) in cw.counts.keys())


def test_unicode_warning_message():
    """Non-ASCII message text (CJK, Cyrillic, emoji) is recorded verbatim."""
    expected_text = "warning with unicode: 你好 мир 🌍"

    with CollateWarnings(print_on_exit=False) as cw:
        warnings.warn(expected_text, UserWarning)

    assert len(cw.counts) == 1
    assert expected_text in [message for (_, _, _, message) in cw.counts.keys()]


def test_custom_warning_class():
    """A user-defined warning subclass is recorded under its own class name."""

    class CustomWarning(UserWarning):
        pass

    with CollateWarnings(print_on_exit=False) as cw:
        warnings.warn("custom warning", CustomWarning)

    assert len(cw.counts) == 1
    recorded_categories = [category for (_, _, category, _) in cw.counts.keys()]
    assert "CustomWarning" in recorded_categories


def test_default_format_string():
    """Test the default "({count}x) {filename}:{lineno} {category}: {message}" output."""
    buf = StringIO()
    with redirect_stderr(buf), CollateWarnings(print_on_exit=True) as cw:
        warnings.warn("test default format", UserWarning)

    stderr_output = buf.getvalue().strip()

    # The previous `":" in output` check was vacuous ("UserWarning: ..." already
    # contains a colon); assert the full line using the recorded location.
    ((filename, lineno, _, _),) = cw.counts.keys()
    assert stderr_output.startswith("(1x)")
    assert (
        f"(1x) {filename}:{lineno} UserWarning: test default format" in stderr_output
    )


def test_collate_warnings_with_warnings_always():
    """Test that CollateWarnings overrides outer filters so every warning is recorded."""
    with warnings.catch_warnings():
        # An outer "ignore" filter would normally drop every UserWarning (and
        # the default action would keep only the first one per location).
        # CollateWarnings must install its own "always" filter on top of it.
        warnings.simplefilter("ignore")
        with CollateWarnings(print_on_exit=False) as cw:
            for _ in range(3):
                warnings.warn("repeated warning", UserWarning)

    # One collated entry, seen all three times
    assert len(cw.counts) == 1
    (count,) = cw.counts.values()
    assert count == 3


def test_multiple_warnings_same_line():
    """Test that different messages issued from the same line are counted separately."""
    with CollateWarnings(print_on_exit=False) as cw:
        # Loop so both warnings genuinely originate from one source line; a
        # `a(); b()` one-liner gets split apart by the formatter.
        for msg in ("warning 1", "warning 2"):
            warnings.warn(msg, UserWarning)

    # Same filename/lineno/category, different messages -> two distinct keys
    assert len(cw.counts) == 2
    assert len({(filename, lineno) for (filename, lineno, _, _) in cw.counts}) == 1
    assert {message for (_, _, _, message) in cw.counts} == {"warning 1", "warning 2"}


def test_counts_accessible_after_exit():
    """``counts`` remains readable and iterable once the context has exited."""
    with CollateWarnings(print_on_exit=False) as cw:
        warnings.warn("test warning", UserWarning)

    # the with-block is closed by now
    assert len(cw.counts) == 1
    assert cw.counts is not None

    for entry_key, tally in cw.counts.items():
        assert isinstance(entry_key, tuple)
        assert isinstance(tally, int)


def test_print_on_exit_default_true():
    """Test that print_on_exit defaults to True."""
    buf = StringIO()
    # no print_on_exit argument: the default must print on exit
    with redirect_stderr(buf), CollateWarnings():
        warnings.warn("default print test", UserWarning)

    assert "default print test" in buf.getvalue()


def test_exit_twice_fails():
    """A second ``__exit__`` without a fresh ``__enter__`` raises RuntimeError."""
    collator = CollateWarnings(print_on_exit=False)

    collator.__enter__()
    collator.__exit__(None, None, None)  # the first, legitimate exit

    with pytest.raises(RuntimeError, match="exited twice"):
        collator.__exit__(None, None, None)