Coverage for muutils / logger / logger.py: 73%

100 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-18 22:24 -0600

1"""logger with streams & levels, and a timer context manager 

2 

3- `SimpleLogger` is an extremely simple logger that can write to both console and a file 

4- `Logger` class handles levels in a slightly different way than default python `logging`, 

5 and also has "streams" which allow for different sorts of output in the same logger 

6 this was mostly made with training models in mind and storing both metadata and loss 

7- `TimerContext` is a context manager that can be used to time the duration of a block of code 

8""" 

9 

10from __future__ import annotations 

11 

12import json 

13import time 

14import typing 

15from functools import partial 

16from typing import Any, Callable, Sequence 

17 

18from muutils.json_serialize import JSONitem, json_serialize 

19from muutils.logger.exception_context import ExceptionContext 

20from muutils.logger.headerfuncs import HEADER_FUNCTIONS, HeaderFunction 

21from muutils.logger.loggingstream import LoggingStream 

22from muutils.logger.simplelogger import AnyIO, SimpleLogger 

23 

24# pylint: disable=arguments-differ, bad-indentation, trailing-whitespace, trailing-newlines, unnecessary-pass, consider-using-with, use-dict-literal 

25 

26 

def decode_level(level: int) -> str:
    """convert a numeric log level to a human-readable label

    levels below `-255` are fatal errors, negative levels are warnings,
    and everything else (zero and up) is informational

    # Parameters:
    - `level : int`
        numeric log level

    # Returns:
    - `str`
        one of `"FATAL_ERROR({level})"`, `"WARNING({level})"`, `"INFO({level})"`

    # Raises:
    - `TypeError` : if `level` is not an `int`
    """
    if not isinstance(level, int):
        raise TypeError(f"level must be int, got {type(level) = } {level = }")

    if level < -255:
        category: str = "FATAL_ERROR"
    elif level < 0:
        category = "WARNING"
    else:
        category = "INFO"

    return f"{category}({level})"

37 

38 

39# todo: add a context which catches and logs all exceptions 

class Logger(SimpleLogger):
    """logger with more features, including log levels and streams

    # Parameters:
    - `log_path : str | None`
        default log file path
        (defaults to `None`)
    - `log_file : AnyIO | None`
        default log io, should have a `.write()` method (pass only this or `log_path`, not both)
        (defaults to `None`)
    - `timestamp : bool`
        whether to add timestamps to every log message (under the `_timestamp` key)
        (defaults to `True`)
    - `default_level : int`
        default log level for streams/messages that don't specify a level
        (defaults to `0`)
    - `console_print_threshold : int`
        log level at which to print to the console, anything greater will not be printed unless overridden by `console_print`
        (defaults to `50`)
    - `level_header : HeaderFunction`
        function for formatting log messages when printing to console
        (defaults to `HEADER_FUNCTIONS["md"]`)
    - `streams : dict[str | None, LoggingStream] | Sequence[LoggingStream]`
        streams to register, in addition to the default `"error"` stream
        (defaults to `()`)
    - `keep_last_msg_time : bool`
        whether to track the time of the last console-printed message
        (used by `log_elapsed_last`)
        (defaults to `True`)

    # Raises:
    - `ValueError` : on unrecognized `kwargs`, on `timestamp=False`,
        or when two streams declare the same alias
    """

    def __init__(
        self,
        log_path: str | None = None,
        log_file: AnyIO | None = None,
        default_level: int = 0,
        console_print_threshold: int = 50,
        level_header: HeaderFunction = HEADER_FUNCTIONS["md"],
        streams: dict[str | None, LoggingStream] | Sequence[LoggingStream] = (),
        keep_last_msg_time: bool = True,
        # junk args
        timestamp: bool = True,
        **kwargs: Any,
    ) -> None:
        # junk arg checking
        # ==================================================
        if len(kwargs) > 0:
            raise ValueError(f"unrecognized kwargs: {kwargs}")

        if not timestamp:
            raise ValueError(
                "timestamp must be True -- why would you not want timestamps?"
            )

        # timing
        # ==================================================
        # timing compares
        self._keep_last_msg_time: bool = keep_last_msg_time
        # TODO: handle per stream?
        # BUGFIX: `keep_last_msg_time=False` previously had no effect because
        # `_last_msg_time` was always initialized; `None` here disables tracking
        # (and makes `log_elapsed_last` raise, as it documents)
        self._last_msg_time: float | None = time.time() if keep_last_msg_time else None

        # basic setup
        # ==================================================
        # init BaseLogger (opens/assigns the main log file handle)
        super().__init__(log_file=log_file, log_path=log_path, timestamp=timestamp)

        # level-related
        self._console_print_threshold: int = console_print_threshold
        self._default_level: int = default_level

        # set up streams (normalize a sequence of streams into a name -> stream map)
        self._streams: dict[str | None, LoggingStream] = (
            streams if isinstance(streams, dict) else {s.name: s for s in streams}  # ty: ignore[invalid-assignment]
        )
        # default error stream
        if "error" not in self._streams:
            self._streams["error"] = LoggingStream(
                "error",
                aliases={
                    "err",
                    "except",
                    "Exception",
                    "exception",
                    "exceptions",
                    "errors",
                },
            )

        # check alias duplicates
        alias_set: set[str | None] = set()
        for stream in self._streams.values():
            for alias in stream.aliases:
                if alias in alias_set:
                    raise ValueError(f"alias {alias} is already in use")
                alias_set.add(alias)

        # add aliases (iterate over a snapshot since we mutate the dict)
        for stream in tuple(self._streams.values()):
            for alias in stream.aliases:
                if alias not in self._streams:
                    self._streams[alias] = stream

        # print formatting
        self._level_header: HeaderFunction = level_header

        # NOTE(fix): removed stray debug `print()` of the internal stream map --
        # a constructor should not write internal state to stdout

    def _exception_context(
        self,
        stream: str = "error",
        # level: int = -256,
        # **kwargs,
    ) -> ExceptionContext:
        """return an `ExceptionContext` writing to the given stream's handler

        falls back to `sys.stderr` when the stream has no handler of its own
        """
        import sys

        s: LoggingStream = self._streams[stream]
        handler = s.handler if s.handler is not None else sys.stderr
        return ExceptionContext(stream=handler)

    def log(
        self,
        msg: JSONitem = None,
        *,
        lvl: int | None = None,
        stream: str | None = None,
        console_print: bool = False,
        extra_indent: str = "",
        **kwargs: Any,
    ) -> None:
        """logging function

        ### Parameters:
        - `msg : JSONitem`
            message (usually string or dict) to be logged
        - `lvl : int | None`
            level of message (lower levels are more important)
            (defaults to `None`)
        - `stream : str | None`
            whether to log to a stream (defaults to `None`), which logs to the default `None` stream
            (defaults to `None`)
        - `console_print : bool`
            override `console_print_threshold` setting
            (defaults to `False`)
        - `extra_indent : str`
            extra indentation passed to the header function for console output
            (defaults to `""`)
        - `**kwargs`
            any additional data, stored under the `_kwargs` key in the log entry
        """

        # add to known stream names if not present
        if stream not in self._streams:
            self._streams[stream] = LoggingStream(stream)

        # set default level to either global or stream-specific default level
        # ========================================
        if lvl is None:
            if stream is None:
                lvl = self._default_level
            else:
                if self._streams[stream].default_level is not None:
                    lvl = self._streams[stream].default_level
                else:
                    lvl = self._default_level

        assert lvl is not None, "lvl should not be None at this point"

        # print to console with formatting
        # ========================================
        _printed: bool = False
        if console_print or (lvl <= self._console_print_threshold):
            # add some formatting
            print(
                self._level_header(
                    msg=msg,
                    lvl=lvl,
                    stream=stream,
                    extra_indent=extra_indent,
                )
            )

            # store the last message time (only when tracking is enabled)
            if self._last_msg_time is not None:
                self._last_msg_time = time.time()

            _printed = True

        # convert and add data
        # ========================================
        # converting to dict
        msg_dict: dict[str, Any]
        if not isinstance(msg, typing.Mapping):
            msg_dict = {"_msg": msg}
        else:
            msg_dict = dict(typing.cast(typing.Mapping[str, Any], msg))

        # level+stream metadata
        if lvl is not None:
            msg_dict["_lvl"] = lvl

        # msg_dict["_stream"] = stream # moved to LoggingStream

        # extra data in kwargs
        if len(kwargs) > 0:
            msg_dict["_kwargs"] = kwargs

        # add default contents (timing, etc)
        msg_dict = {
            **{k: v() for k, v in self._streams[stream].default_contents.items()},
            **msg_dict,
        }

        # write
        # ========================================
        logfile_msg: str = json.dumps(json_serialize(msg_dict)) + "\n"
        # route to the stream's own handler when one exists; the default `None`
        # stream and handler-less streams go to the main log file
        # (the stream is guaranteed to be in `self._streams` -- it was added above)
        s_handler: AnyIO | None = (
            self._streams[stream].handler if stream is not None else None
        )
        if s_handler is not None:
            s_handler.write(logfile_msg)
        else:
            s_handler.write(logfile_msg) if False else self._log_file_handle.write(logfile_msg)

        # if it was important enough to print, flush all streams
        if _printed:
            self.flush_all()

    def log_elapsed_last(
        self,
        lvl: int | None = None,
        stream: str | None = None,
        console_print: bool = True,
        **kwargs: Any,
    ) -> None:
        """logs the time elapsed since the last message was printed to the console (in any stream)

        # Raises:
        - `ValueError` : when last-message-time tracking is disabled
            (`keep_last_msg_time=False`)
        """
        if self._last_msg_time is None:
            raise ValueError("no last message time!")
        else:
            self.log(
                {"elapsed_time": round(time.time() - self._last_msg_time, 6)},
                lvl=(lvl if lvl is not None else self._console_print_threshold),
                stream=stream,
                console_print=console_print,
                **kwargs,
            )

    def flush_all(self):
        """flush all streams"""

        self._log_file_handle.flush()

        # aliases point at the same stream objects, so a handler may be
        # flushed more than once -- flush is idempotent, so this is harmless
        for stream in self._streams.values():
            if stream.handler is not None:
                stream.handler.flush()

    def __getattr__(self, stream: str) -> Callable[..., Any]:
        """allow `logger.mystream(msg)` as shorthand for `logger.log(msg, stream="mystream")`

        underscore-prefixed names are rejected so dunder lookups
        (pickling, copying, etc.) fail with a normal `AttributeError`
        """
        if stream.startswith("_"):
            raise AttributeError(f"invalid stream name {stream} (no underscores)")
        return partial(self.log, stream=stream)

    def __getitem__(self, stream: str) -> Callable[..., Any]:
        """allow `logger["mystream"](msg)` as shorthand for `logger.log(msg, stream="mystream")`"""
        return partial(self.log, stream=stream)

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """calling the logger directly forwards to `log`"""
        self.log(*args, **kwargs)