Coverage for tests/unit/test_mlutils.py: 86%
43 statements
Generated by coverage.py v7.6.1 at 2025-04-04 03:33 -0600
import shutil
import sys
from pathlib import Path

from muutils.mlutils import get_checkpoint_paths_for_run, register_method
def test_get_checkpoint_paths_for_run():
    """Check that only ``model.iter_<N>.pt`` checkpoints are reported.

    Builds a run directory containing two ``model.iter_<N>.pt`` files and one
    unrelated ``.txt`` file, then asserts that
    ``get_checkpoint_paths_for_run(run_path, "pt")`` returns exactly the two
    checkpoints as ``(iteration, path)`` tuples.
    """
    run_path = Path("tests/_temp/test_get_checkpoint_paths")
    # Start from an empty directory: this path is fixed and never cleaned up,
    # so leftover checkpoint files from an earlier (or aborted) run would
    # otherwise leak into the result and break the exact-list assertion.
    shutil.rmtree(run_path, ignore_errors=True)
    run_path.mkdir(parents=True, exist_ok=True)

    checkpoints_path = run_path / "checkpoints"
    checkpoint1_path = checkpoints_path / "model.iter_123.pt"
    checkpoint2_path = checkpoints_path / "model.iter_456.pt"
    # Non-checkpoint file; the assertion below requires it to be excluded.
    other_path = checkpoints_path / "other_file.txt"

    checkpoints_path.mkdir(exist_ok=True)
    checkpoint1_path.touch()
    checkpoint2_path.touch()
    other_path.touch()

    checkpoint_paths = get_checkpoint_paths_for_run(run_path, "pt")

    # NOTE(review): this comparison is order-sensitive, so it assumes the helper
    # returns checkpoints sorted by iteration — confirm against its contract.
    assert checkpoint_paths == [(123, checkpoint1_path), (456, checkpoint2_path)]
# test_register_method expects different warning counts and registry contents
# before Python 3.10, so its assertions branch on this flag.
BELOW_PY_3_10: bool = sys.version_info < (3, 10)
def test_register_method(recwarn):
    """Check that ``register_method`` records decorated static methods in a dict.

    Two classes each register one ``@staticmethod`` into their own ``evals``
    dict via ``@register_method(evals)``. ``TestEvalsA.other_function`` is
    deliberately left undecorated.

    On Python >= 3.10 the test expects no warnings, and each dict must be keyed
    by exactly the decorated function's name. Below 3.10 it expects two
    warnings and only checks that each dict holds a single entry.

    ``recwarn`` is pytest's warning-recording fixture.
    """

    class TestEvalsA:
        evals: dict = {}

        @register_method(evals)
        @staticmethod
        def eval_function():
            pass

        # Not decorated, so it must not show up in ``evals``.
        @staticmethod
        def other_function():
            pass

    class TestEvalsB:
        evals: dict = {}

        @register_method(evals)
        @staticmethod
        def other_eval_function():
            pass

    # NOTE(review): the version split presumably tracks Python 3.10 giving
    # staticmethod objects ``__name__``/``__wrapped__``; confirm against
    # register_method's implementation.
    if BELOW_PY_3_10:
        # presumably one warning per register_method(...) application above
        assert len(recwarn) == 2
    else:
        assert len(recwarn) == 0

    evalsA = TestEvalsA.evals
    evalsB = TestEvalsB.evals
    if BELOW_PY_3_10:
        # Keys are not pinned on older Pythons; only check that each
        # registration produced exactly one entry.
        assert len(evalsA) == 1
        assert len(evalsB) == 1
    else:
        assert list(evalsA) == ["eval_function"]
        assert list(evalsB) == ["other_eval_function"]