Coverage for tests/unit/test_dbg.py: 98%

116 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 16:45 -0600

1import inspect 

2import tempfile 

3from pathlib import Path 

4import importlib 

5from typing import Any, Callable, Optional, List, Tuple 

6 

7import pytest 

8 

9from muutils.dbg import ( 

10 dbg, 

11 _NoExpPassed, 

12 _process_path, 

13 _CWD, 

14 # we do use this as a global in `test_dbg_counter_increments` 

15 _COUNTER, # noqa: F401 

16) 

17 

18 

# Dotted import path of the module under test. Tests pass it to
# `importlib.import_module` so they can monkeypatch module-level globals
# (`PATH_MODE`, `_COUNTER`) on the real module object rather than on names
# imported into this file.
DBG_MODULE_NAME: str = "muutils.dbg"

20 

21# ============================================================================ 

22# Dummy Tensor classes for testing tensor_info* functions 

23# ============================================================================ 

24 

25 

class DummyTensor:
    """Tensor stand-in whose ``sum()`` evaluates to NaN.

    Carries plain class attributes mimicking the tensor metadata that the
    tensor_info* helpers are meant to inspect.
    """

    shape: Tuple[int, ...] = (2, 3)
    dtype: str = "float32"
    device: str = "cpu"
    requires_grad: bool = False

    def sum(self) -> float:
        # NaN is the edge case of interest: it never compares equal to itself.
        nan_total: float = float("nan")
        return nan_total

36 

37 

class DummyTensorNormal:
    """Tensor stand-in with a finite ``sum()`` and non-default metadata.

    Uses a different shape/dtype/device/requires_grad combination from
    ``DummyTensor`` so the two fixtures can be told apart.
    """

    shape: Tuple[int, ...] = (4, 5)
    dtype: str = "int32"
    device: str = "cuda"
    requires_grad: bool = True

    def sum(self) -> float:
        # fixed, ordinary (non-NaN) total
        finite_total: float = 20.0
        return finite_total

48 

49 

class DummyTensorPartial:
    """Tensor stand-in that defines ``shape`` and nothing else.

    Lets tests check how code copes when dtype/device/requires_grad/sum are
    all missing.
    """

    # the only metadata attribute provided
    shape: Tuple[int, ...] = (7,)

54 

55 

56# ============================================================================ 

57# Additional Tests for dbg and tensor_info functionality 

58# ============================================================================ 

59 

60 

61# --- Tests for _process_path (existing ones) --- 

def test_process_path_absolute(monkeypatch: pytest.MonkeyPatch) -> None:
    """With PATH_MODE "absolute", _process_path yields the absolute POSIX path."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "absolute")

    target: Path = Path("somefile.txt")
    want: str = target.absolute().as_posix()
    assert _process_path(target) == want

70 

71 

def test_process_path_relative_inside_common(monkeypatch: pytest.MonkeyPatch) -> None:
    """With PATH_MODE "relative", a path under `_CWD` is rendered relative to it."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "relative")

    inside: Path = _CWD / "file.txt"
    assert _process_path(inside) == "file.txt"

80 

81 

def test_process_path_relative_outside_common(monkeypatch: pytest.MonkeyPatch) -> None:
    """With PATH_MODE "relative", a path outside `_CWD` falls back to absolute form."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "relative")

    # NOTE(review): relies on the system temp directory not living under
    # `_CWD`; if TMPDIR points inside the project this expectation breaks.
    with tempfile.TemporaryDirectory() as scratch:
        outside: Path = Path(scratch).joinpath("file.txt")
        want: str = outside.absolute().as_posix()
        assert _process_path(outside) == want

91 

92 

def test_process_path_invalid_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unrecognised PATH_MODE makes _process_path raise ValueError."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "invalid")

    # `match` is a regex search, so a prefix of the message is sufficient.
    message_prefix = "PATH_MODE must be either 'relative' or 'absolute"
    with pytest.raises(ValueError, match=message_prefix):
        _process_path(Path("anything.txt"))

101 

102 

103# --- Tests for dbg --- 

def test_dbg_with_expression(capsys: pytest.CaptureFixture) -> None:
    """dbg(expr) writes the expression text and its value to stderr and returns the value."""
    result: int = dbg(1 + 2)
    captured: str = capsys.readouterr().err
    assert "= 3" in captured
    # Strip spaces before comparing so both "1+2" and "1 + 2" renderings of
    # the source snippet are accepted by this single check.
    assert "1+2" in captured.replace(" ", "")
    assert result == 3

111 

112 

def test_dbg_without_expression(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """dbg() with no argument prints a numbered marker and returns the _NoExpPassed sentinel."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    # start the module's counter from a known value for this test only
    monkeypatch.setattr(dbg_module, "_COUNTER", 0)

    returned: Any = dbg()
    stderr_text: str = capsys.readouterr().err.strip()

    assert "<dbg 0>" in stderr_text
    assert returned is _NoExpPassed

122 

123 

def test_dbg_custom_formatter(capsys: pytest.CaptureFixture) -> None:
    """A caller-supplied formatter's output shows up on stderr; the value is returned untouched."""

    def custom_formatter(x: Any) -> str:
        # ignores its input on purpose so the output is easy to spot
        return "custom"

    result: str = dbg("anything", formatter=custom_formatter)
    stderr_text: str = capsys.readouterr().err
    assert "custom" in stderr_text
    assert result == "anything"

130 

131 

def test_dbg_complex_expression(capsys: pytest.CaptureFixture) -> None:
    """dbg copes with an immediately-invoked lambda as its argument."""
    result: int = dbg((lambda x: x * x)(5))
    stderr_text: str = capsys.readouterr().err
    # the extracted source snippet should still contain the lambda ...
    assert "lambda" in stderr_text
    # ... and the evaluated value, 5 * 5, should be printed
    assert "25" in stderr_text
    assert result == 25

141 

142 

def test_dbg_multiline_code_context(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """dbg locates its call in a frame whose code_context has several lines.

    `inspect.stack` is replaced by a fake two-frame stack. The first frame's
    code context does not mention "dbg"; the second one does, followed by an
    extra trailing line labelled "ignored line".
    """

    class FakeFrame:
        # Minimal frame stand-in exposing only code_context/filename/lineno.
        def __init__(
            self, code_context: Optional[List[str]], filename: str, lineno: int
        ) -> None:
            self.code_context = code_context
            self.filename = filename
            self.lineno = lineno

    def fake_inspect_stack() -> List[Any]:
        return [
            FakeFrame(["not line"], "frame1.py", 20),
            FakeFrame(["dbg(2+2)", "ignored line"], "frame2.py", 30),
        ]

    monkeypatch.setattr(inspect, "stack", fake_inspect_stack)
    result: int = dbg(2 + 2)
    captured: str = capsys.readouterr().err
    assert "2+2" in captured
    assert "4" in captured
    assert result == 4

169 

170 

def test_dbg_counter_increments(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """Consecutive argument-less dbg() calls print consecutive counter values.

    The counter must be reset on the dbg module object itself. Rebinding a
    name imported into this test module (e.g. with ``global _COUNTER``) only
    changes this file's copy and leaves dbg's counter untouched, which made
    the test depend on execution order. monkeypatch also restores the
    module's original value after the test.
    """
    monkeypatch.setattr(importlib.import_module(DBG_MODULE_NAME), "_COUNTER", 0)
    dbg()
    out1: str = capsys.readouterr().err
    dbg()
    out2: str = capsys.readouterr().err
    assert "<dbg 0>" in out1
    assert "<dbg 1>" in out2

180 

181 

def test_dbg_formatter_exception() -> None:
    """An exception raised by the formatter propagates out of dbg unchanged."""
    failure_message = "formatter error"

    def bad_formatter(x: Any) -> str:
        raise ValueError(failure_message)

    with pytest.raises(ValueError, match=failure_message):
        dbg(123, formatter=bad_formatter)

188 

189 

def test_dbg_incomplete_expression(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """dbg tolerates a source line whose call has no closing parenthesis."""

    class FakeFrame:
        # Minimal frame stand-in exposing only code_context/filename/lineno.
        def __init__(
            self, code_context: Optional[List[str]], filename: str, lineno: int
        ) -> None:
            self.code_context = code_context
            self.filename = filename
            self.lineno = lineno

    def fake_inspect_stack() -> List[Any]:
        unterminated = FakeFrame(["dbg(42"], "fake_incomplete.py", 100)
        return [unterminated]

    monkeypatch.setattr(inspect, "stack", fake_inspect_stack)
    value: int = dbg(42)
    stderr_text: str = capsys.readouterr().err
    # The intended extracted expression is "42" (no closing paren to stop at).
    # NOTE(review): the printed value is also 42, so this check cannot tell
    # the expression text apart from the value.
    assert "42" in stderr_text
    assert value == 42

211 

212 

def test_dbg_non_callable_formatter() -> None:
    """Passing a non-callable object as the formatter makes dbg raise TypeError."""
    # typed as Any so the deliberate misuse needs no `type: ignore`
    not_a_formatter: Any = "not callable"
    with pytest.raises(TypeError):
        dbg(42, formatter=not_a_formatter)

216 

217 

218# # --- Tests for tensor_info_dict and tensor_info --- 

219# def test_tensor_info_dict_with_nan() -> None: 

220# tensor: DummyTensor = DummyTensor() 

221# info: Dict[str, str] = tensor_info_dict(tensor) 

222# expected: Dict[str, str] = { 

223# "shape": repr((2, 3)), 

224# "sum": repr(float("nan")), 

225# "dtype": repr("float32"), 

226# "device": repr("cpu"), 

227# "requires_grad": repr(False), 

228# } 

229# assert info == expected 

230 

231 

232# def test_tensor_info_dict_normal() -> None: 

233# tensor: DummyTensorNormal = DummyTensorNormal() 

234# info: Dict[str, str] = tensor_info_dict(tensor) 

235# expected: Dict[str, str] = { 

236# "shape": repr((4, 5)), 

237# "dtype": repr("int32"), 

238# "device": repr("cuda"), 

239# "requires_grad": repr(True), 

240# } 

241# assert info == expected 

242 

243 

244# def test_tensor_info_dict_partial() -> None: 

245# tensor: DummyTensorPartial = DummyTensorPartial() 

246# info: Dict[str, str] = tensor_info_dict(tensor) 

247# expected: Dict[str, str] = {"shape": repr((7,))} 

248# assert info == expected 

249 

250 

251# def test_tensor_info() -> None: 

252# tensor: DummyTensorNormal = DummyTensorNormal() 

253# info_str: str = tensor_info(tensor) 

254# expected: str = ", ".join( 

255# [ 

256# f"shape={repr((4, 5))}", 

257# f"dtype={repr('int32')}", 

258# f"device={repr('cuda')}", 

259# f"requires_grad={repr(True)}", 

260# ] 

261# ) 

262# assert info_str == expected 

263 

264 

265# def test_tensor_info_dict_no_attributes() -> None: 

266# class DummyEmpty: 

267# pass 

268 

269# dummy = DummyEmpty() 

270# info: Dict[str, str] = tensor_info_dict(dummy) 

271# assert info == {} 

272 

273 

274# def test_tensor_info_no_attributes() -> None: 

275# class DummyEmpty: 

276# pass 

277 

278# dummy = DummyEmpty() 

279# info_str: str = tensor_info(dummy) 

280# assert info_str == "" 

281 

282 

283# def test_dbg_tensor(capsys: pytest.CaptureFixture) -> None: 

284# tensor: DummyTensorPartial = DummyTensorPartial() 

285# result: DummyTensorPartial = dbg_tensor(tensor) # type: ignore 

286# captured: str = capsys.readouterr().err 

287# assert "shape=(7,)" in captured 

288# assert result is tensor