Coverage for tests/unit/test_dbg.py: 98%
116 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 16:45 -0600
1import inspect
2import tempfile
3from pathlib import Path
4import importlib
5from typing import Any, Callable, Optional, List, Tuple
7import pytest
9from muutils.dbg import (
10 dbg,
11 _NoExpPassed,
12 _process_path,
13 _CWD,
14 # we do use this as a global in `test_dbg_counter_increments`
15 _COUNTER, # noqa: F401
16)
# Dotted path of the module under test. Tests import it via importlib so that
# monkeypatch can target the module's globals (e.g. PATH_MODE, _COUNTER) directly.
DBG_MODULE_NAME: str = "muutils.dbg"
21# ============================================================================
22# Dummy Tensor classes for testing tensor_info* functions
23# ============================================================================
class DummyTensor:
    """Tensor stand-in with full metadata whose ``sum()`` always yields NaN."""

    shape: Tuple[int, ...] = (2, 3)
    dtype: str = "float32"
    device: str = "cpu"
    requires_grad: bool = False

    def sum(self) -> float:
        """Return NaN so NaN-handling code paths can be exercised."""
        nan_value: float = float("nan")
        return nan_value
class DummyTensorNormal:
    """Tensor stand-in with full metadata and a finite ``sum()``."""

    shape: Tuple[int, ...] = (4, 5)
    dtype: str = "int32"
    device: str = "cuda"
    requires_grad: bool = True

    def sum(self) -> float:
        """Always report the fixed, finite total 20.0."""
        total: float = 20.0
        return total
class DummyTensorPartial:
    """Minimal tensor stand-in: only ``shape`` is defined; dtype, device, etc. are absent."""

    shape: Tuple[int, ...] = (7,)
# ============================================================================
# Tests for _process_path and dbg
# (the tensor_info* tests at the bottom of this file are currently commented out)
# ============================================================================


# --- Tests for _process_path ---
def test_process_path_absolute(monkeypatch: pytest.MonkeyPatch) -> None:
    """With PATH_MODE='absolute', _process_path returns the absolute POSIX path."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "absolute")

    target: Path = Path("somefile.txt")
    expected: str = target.absolute().as_posix()
    assert _process_path(target) == expected
def test_process_path_relative_inside_common(monkeypatch: pytest.MonkeyPatch) -> None:
    """With PATH_MODE='relative', a path directly under ``_CWD`` is rendered relative to it."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "relative")

    inside_cwd: Path = _CWD / "file.txt"
    assert _process_path(inside_cwd) == "file.txt"
def test_process_path_relative_outside_common(monkeypatch: pytest.MonkeyPatch) -> None:
    """With PATH_MODE='relative', a path outside ``_CWD`` falls back to its absolute POSIX form.

    Assumes the system temporary directory does not live under ``_CWD``.
    """
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "relative")

    with tempfile.TemporaryDirectory() as scratch_dir:
        outside: Path = Path(scratch_dir).joinpath("file.txt")
        expected: str = outside.absolute().as_posix()
        assert _process_path(outside) == expected
def test_process_path_invalid_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unrecognised PATH_MODE makes _process_path raise ValueError."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "invalid")

    # `match` is a regex search, so this prefix of the message is sufficient
    message_prefix: str = "PATH_MODE must be either 'relative' or 'absolute"
    with pytest.raises(ValueError, match=message_prefix):
        _process_path(Path("anything.txt"))
103# --- Tests for dbg ---
def test_dbg_with_expression(capsys: pytest.CaptureFixture) -> None:
    """dbg(expr) writes the expression text and its value to stderr and returns the value."""
    value: int = dbg(1 + 2)
    stderr_text: str = capsys.readouterr().err
    assert "= 3" in stderr_text
    # whitespace-insensitive check: accepts both "1+2" and "1 + 2" renderings
    assert "1+2" in stderr_text.replace(" ", "")
    assert value == 3
def test_dbg_without_expression(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """dbg() with no argument prints a ``<dbg N>`` marker and returns the ``_NoExpPassed`` sentinel."""
    # start the module-level counter from a known value for this test only
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "_COUNTER", 0)

    returned: Any = dbg()
    stderr_text: str = capsys.readouterr().err.strip()
    assert "<dbg 0>" in stderr_text
    assert returned is _NoExpPassed
def test_dbg_custom_formatter(capsys: pytest.CaptureFixture) -> None:
    """A user-supplied formatter renders the value; dbg still returns the original value.

    dbg also echoes the call's source text (``"anything", formatter=custom_formatter``).
    The formatter output must therefore be a token that cannot appear in that echo.
    Otherwise the assertion would pass even if the formatter were never called.
    """

    def custom_formatter(x: Any) -> str:
        # embed the received value so we also check dbg passes it to the formatter
        return f"formatted:{x}"

    result: str = dbg("anything", formatter=custom_formatter)
    captured: str = capsys.readouterr().err
    assert "formatted:anything" in captured
    assert result == "anything"
def test_dbg_complex_expression(capsys: pytest.CaptureFixture) -> None:
    """dbg echoes the source of a non-trivial expression (an inline lambda call) and its value."""
    value: int = dbg((lambda x: x * x)(5))
    stderr_text: str = capsys.readouterr().err
    # the echoed source snippet should contain the lambda itself ...
    assert "lambda" in stderr_text
    # ... and the evaluated result 5 * 5
    assert "25" in stderr_text
    assert value == 25
def test_dbg_multiline_code_context(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """dbg picks the expression from the stack frame whose code mentions ``dbg``.

    ``inspect.stack`` is replaced by a fake that returns two frames:

    - the first frame's code context does not contain "dbg";
    - the second frame's does, and it carries an extra line of context.

    The test expects dbg to skip the first frame and report "2+2" from the second.
    """

    class FakeFrame:
        """Minimal stand-in for the records returned by ``inspect.stack()``."""

        def __init__(
            self, code_context: Optional[List[str]], filename: str, lineno: int
        ) -> None:
            self.code_context = code_context
            self.filename = filename
            self.lineno = lineno

    def fake_inspect_stack(*_args: Any, **_kwargs: Any) -> List[Any]:
        # accept and ignore arguments so the fake still works if dbg calls
        # ``inspect.stack(context=...)``
        return [
            FakeFrame(["not line"], "frame1.py", 20),
            FakeFrame(["dbg(2+2)", "ignored line"], "frame2.py", 30),
        ]

    monkeypatch.setattr(inspect, "stack", fake_inspect_stack)
    result: int = dbg(2 + 2)
    captured: str = capsys.readouterr().err
    # "2+2" (no spaces) can only come from the fake frame, not from the real call line
    assert "2+2" in captured
    assert "4" in captured
    assert result == 4
def test_dbg_counter_increments(capsys: pytest.CaptureFixture) -> None:
    """Successive argument-less dbg() calls print ``<dbg 0>`` and then ``<dbg 1>``.

    The counter is a global of the dbg module, so it must be reset there.
    A ``global _COUNTER`` statement in this test module would only rebind this
    module's copy of the imported name and leave dbg's counter untouched.
    """
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    # Patch the counter on the dbg module itself. The context manager restores
    # the previous value afterwards, so earlier or later tests do not affect this one.
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(dbg_module, "_COUNTER", 0)
        dbg()
        out1: str = capsys.readouterr().err
        dbg()
        out2: str = capsys.readouterr().err
    assert "<dbg 0>" in out1
    assert "<dbg 1>" in out2
def test_dbg_formatter_exception() -> None:
    """An exception raised inside the formatter propagates out of dbg."""

    def bad_formatter(x: Any) -> str:
        raise ValueError("formatter error")

    expected_message: str = "formatter error"
    with pytest.raises(ValueError, match=expected_message):
        dbg(123, formatter=bad_formatter)
def test_dbg_incomplete_expression(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """dbg copes with a call line that has no closing parenthesis (``dbg(42``)."""

    class FakeFrame:
        """Minimal stand-in for the records returned by ``inspect.stack()``."""

        def __init__(
            self, code_context: Optional[List[str]], filename: str, lineno: int
        ) -> None:
            self.code_context = code_context
            self.filename = filename
            self.lineno = lineno

    def fake_inspect_stack() -> List[Any]:
        truncated = FakeFrame(["dbg(42"], "fake_incomplete.py", 100)
        return [truncated]

    monkeypatch.setattr(inspect, "stack", fake_inspect_stack)
    value: int = dbg(42)
    stderr_text: str = capsys.readouterr().err
    # The intended extracted expression is "42", since there is no ")" to stop at.
    # NOTE(review): the printed value 42 also satisfies this check, so on its own
    # it does not prove how the expression was extracted.
    assert "42" in stderr_text
    assert value == 42
def test_dbg_non_callable_formatter() -> None:
    """Passing a non-callable ``formatter`` makes dbg raise TypeError."""
    not_a_function = "not callable"
    with pytest.raises(TypeError):
        dbg(42, formatter=not_a_function)  # type: ignore
218# # --- Tests for tensor_info_dict and tensor_info ---
219# def test_tensor_info_dict_with_nan() -> None:
220# tensor: DummyTensor = DummyTensor()
221# info: Dict[str, str] = tensor_info_dict(tensor)
222# expected: Dict[str, str] = {
223# "shape": repr((2, 3)),
224# "sum": repr(float("nan")),
225# "dtype": repr("float32"),
226# "device": repr("cpu"),
227# "requires_grad": repr(False),
228# }
229# assert info == expected
232# def test_tensor_info_dict_normal() -> None:
233# tensor: DummyTensorNormal = DummyTensorNormal()
234# info: Dict[str, str] = tensor_info_dict(tensor)
235# expected: Dict[str, str] = {
236# "shape": repr((4, 5)),
237# "dtype": repr("int32"),
238# "device": repr("cuda"),
239# "requires_grad": repr(True),
240# }
241# assert info == expected
244# def test_tensor_info_dict_partial() -> None:
245# tensor: DummyTensorPartial = DummyTensorPartial()
246# info: Dict[str, str] = tensor_info_dict(tensor)
247# expected: Dict[str, str] = {"shape": repr((7,))}
248# assert info == expected
251# def test_tensor_info() -> None:
252# tensor: DummyTensorNormal = DummyTensorNormal()
253# info_str: str = tensor_info(tensor)
254# expected: str = ", ".join(
255# [
256# f"shape={repr((4, 5))}",
257# f"dtype={repr('int32')}",
258# f"device={repr('cuda')}",
259# f"requires_grad={repr(True)}",
260# ]
261# )
262# assert info_str == expected
265# def test_tensor_info_dict_no_attributes() -> None:
266# class DummyEmpty:
267# pass
269# dummy = DummyEmpty()
270# info: Dict[str, str] = tensor_info_dict(dummy)
271# assert info == {}
274# def test_tensor_info_no_attributes() -> None:
275# class DummyEmpty:
276# pass
278# dummy = DummyEmpty()
279# info_str: str = tensor_info(dummy)
280# assert info_str == ""
283# def test_dbg_tensor(capsys: pytest.CaptureFixture) -> None:
284# tensor: DummyTensorPartial = DummyTensorPartial()
285# result: DummyTensorPartial = dbg_tensor(tensor) # type: ignore
286# captured: str = capsys.readouterr().err
287# assert "shape=(7,)" in captured
288# assert result is tensor