Coverage for tests/unit/test_mlutils.py: 86%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1import sys 

2from pathlib import Path 

3 

4from muutils.mlutils import get_checkpoint_paths_for_run, register_method 

5 

6 

def test_get_checkpoint_paths_for_run():
    """Check that ``get_checkpoint_paths_for_run`` returns only matching checkpoints.

    The run directory is created inside a fresh temporary directory.  This keeps
    the test independent of the current working directory and immune to files
    left behind by earlier runs.  The directory is removed afterwards.

    The ``checkpoints`` subdirectory contains:
    - two ``.pt`` checkpoints named ``model.iter_<N>.pt``;
    - an unrelated text file;
    - a checkpoint-named file with a different extension.

    Only the two ``.pt`` files are expected back, each paired with the
    iteration number from its filename.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        run_path = Path(tmp_dir) / "test_get_checkpoint_paths"
        checkpoints_path = run_path / "checkpoints"
        # parents=True also creates run_path; the tree is always fresh here.
        checkpoints_path.mkdir(parents=True)

        checkpoint1_path = checkpoints_path / "model.iter_123.pt"
        checkpoint2_path = checkpoints_path / "model.iter_456.pt"
        # Decoys: a non-checkpoint file, and a checkpoint-shaped name whose
        # extension differs from the one requested below.
        other_path = checkpoints_path / "other_file.txt"
        wrong_extension_path = checkpoints_path / "model.iter_789.zanj"

        for path in (checkpoint1_path, checkpoint2_path, other_path, wrong_extension_path):
            path.touch()

        checkpoint_paths = get_checkpoint_paths_for_run(run_path, "pt")

        assert checkpoint_paths == [(123, checkpoint1_path), (456, checkpoint2_path)]

23 

24 

# Whether the interpreter predates Python 3.10. test_register_method branches on
# this flag because its expectations differ between the two interpreter ranges.
BELOW_PY_3_10: bool = sys.version_info[:2] < (3, 10)

26 

27 

def test_register_method(recwarn):
    """Check that ``register_method`` records decorated static methods per class.

    Two throwaway classes each own a separate ``evals`` dict.  Only methods
    decorated with ``register_method(evals)`` should end up in that class's
    dict; the undecorated ``other_function`` must not.

    Expectations differ by interpreter version:
    - Python < 3.10: exactly two warnings (one per decorated method) and one
      entry per registry.  The keys are not pinned.
    - Python 3.10+: no warnings, and the keys are the function names.

    The version split presumably exists because ``staticmethod`` objects only
    gained ``__name__`` in 3.10; verify against ``register_method``.
    """

    class TestEvalsA:
        evals: dict = {}

        @register_method(evals)
        @staticmethod
        def eval_function():
            pass

        @staticmethod
        def other_function():
            pass

    class TestEvalsB:
        evals: dict = {}

        @register_method(evals)
        @staticmethod
        def other_eval_function():
            pass

    # Warnings are emitted while the class bodies above execute, so count
    # them now, before anything else can add to recwarn.
    expected_warning_count = 2 if BELOW_PY_3_10 else 0
    assert len(recwarn) == expected_warning_count

    expectations = (
        (TestEvalsA.evals, "eval_function"),
        (TestEvalsB.evals, "other_eval_function"),
    )
    for registry, expected_key in expectations:
        if BELOW_PY_3_10:
            # Registered names are not asserted on older interpreters; only
            # the entry count is checked.
            assert len(registry) == 1
        else:
            assert list(registry.keys()) == [expected_key]