Coverage for muutils / logger / logger.py: 73%

100 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 02:51 -0700

1"""logger with streams & levels, and a timer context manager 

2 

3- `SimpleLogger` is an extremely simple logger that can write to both console and a file 

4- `Logger` class handles levels in a slightly different way than default python `logging`, 

5 and also has "streams" which allow for different sorts of output in the same logger 

6 this was mostly made with training models in mind and storing both metadata and loss 

7- `TimerContext` is a context manager that can be used to time the duration of a block of code 

8""" 

9 

10from __future__ import annotations 

11 

12import json 

13import time 

14import typing 

15from functools import partial 

16from typing import Any, Callable, Sequence 

17 

18from muutils.json_serialize import JSONitem, json_serialize 

19from muutils.logger.exception_context import ExceptionContext 

20from muutils.logger.headerfuncs import HEADER_FUNCTIONS, HeaderFunction 

21from muutils.logger.loggingstream import LoggingStream 

22from muutils.logger.simplelogger import AnyIO, SimpleLogger 

23 

24# pylint: disable=arguments-differ, bad-indentation, trailing-whitespace, trailing-newlines, unnecessary-pass, consider-using-with, use-dict-literal 

25 

26 

def decode_level(level: int) -> str:
    """Turn a numeric log level into a human-readable label.

    The label is chosen by range, and the numeric level is appended in
    parentheses:

     - below `-255`           -> `"FATAL_ERROR(<level>)"`
     - from `-255` to `-1`    -> `"WARNING(<level>)"`
     - `0` and above          -> `"INFO(<level>)"`

    # Raises:
     - `TypeError` : if `level` is not an `int`
    """
    if isinstance(level, int):
        if level >= 0:
            label = "INFO"
        elif level >= -255:
            label = "WARNING"
        else:
            label = "FATAL_ERROR"
        return f"{label}({level})"

    raise TypeError(f"level must be int, got {type(level) = } {level = }")

37 

38 

39# todo: add a context which catches and logs all exceptions 

class Logger(SimpleLogger):
    """logger with more features, including log levels and streams

    # Parameters:
     - `log_path : str | None`
        default log file path
        (defaults to `None`)
     - `log_file : AnyIO | None`
        default log io, should have a `.write()` method (pass only this or `log_path`, not both)
        (defaults to `None`)
     - `default_level : int`
        default log level for streams/messages that don't specify a level
        (defaults to `0`)
     - `console_print_threshold : int`
        log level at which to print to the console, anything greater will not be printed unless overridden by `console_print`
        (defaults to `50`)
     - `level_header : HeaderFunction`
        function for formatting log messages when printing to console
        (defaults to `HEADER_FUNCTIONS["md"]`)
     - `streams : dict[str | None, LoggingStream] | Sequence[LoggingStream]`
        streams to register, either keyed by name or as a sequence (keyed by each
        stream's `.name`). a mapping passed here is copied, never mutated.
        an `"error"` stream with common aliases is added if not present.
        (defaults to `()`)
     - `keep_last_msg_time : bool`
        whether to track the time of the last console-printed message, as used by
        `log_elapsed_last`. if `False`, `log_elapsed_last` raises `ValueError`
        (defaults to `True`)
     - `timestamp : bool`
        whether to add timestamps to every log message; must be `True`, and is
        forwarded to `SimpleLogger`
        (defaults to `True`)

    # Raises:
     - `ValueError` : if unrecognized keyword arguments are passed
     - `ValueError` : if `timestamp` is `False`
     - `ValueError` : if two streams declare the same alias
    """

    def __init__(
        self,
        log_path: str | None = None,
        log_file: AnyIO | None = None,
        default_level: int = 0,
        console_print_threshold: int = 50,
        level_header: HeaderFunction = HEADER_FUNCTIONS["md"],
        streams: dict[str | None, LoggingStream] | Sequence[LoggingStream] = (),
        keep_last_msg_time: bool = True,
        # junk args
        timestamp: bool = True,
        **kwargs: Any,
    ) -> None:
        # junk arg checking
        # ==================================================
        if len(kwargs) > 0:
            raise ValueError(f"unrecognized kwargs: {kwargs}")

        if not timestamp:
            raise ValueError(
                "timestamp must be True -- why would you not want timestamps?"
            )

        # timing
        # ==================================================
        self._keep_last_msg_time: bool = keep_last_msg_time
        # `None` means "tracking disabled": `log` only refreshes a non-None value,
        # and `log_elapsed_last` raises when it is `None`
        # TODO: handle per stream?
        self._last_msg_time: float | None = (
            time.time() if keep_last_msg_time else None
        )

        # basic setup
        # ==================================================
        # init BaseLogger
        super().__init__(log_file=log_file, log_path=log_path, timestamp=timestamp)

        # level-related
        self._console_print_threshold: int = console_print_threshold
        self._default_level: int = default_level

        # set up streams
        # a caller-supplied mapping is copied: the default error stream, aliases,
        # and ad-hoc streams are added below and must not leak into the caller's dict
        self._streams: dict[str | None, LoggingStream] = (
            dict(streams)
            if isinstance(streams, typing.Mapping)
            else {s.name: s for s in streams}
        )
        # default error stream
        if "error" not in self._streams:
            self._streams["error"] = LoggingStream(
                "error",
                aliases={
                    "err",
                    "except",
                    "Exception",
                    "exception",
                    "exceptions",
                    "errors",
                },
            )

        # check alias duplicates
        alias_set: set[str | None] = set()
        for stream in self._streams.values():
            for alias in stream.aliases:
                if alias in alias_set:
                    raise ValueError(f"alias {alias} is already in use")
                alias_set.add(alias)

        # add aliases (iterate over a snapshot since we insert into the dict);
        # an alias never overrides an existing stream name
        for stream in tuple(self._streams.values()):
            for alias in stream.aliases:
                if alias not in self._streams:
                    self._streams[alias] = stream

        # print formatting
        self._level_header: HeaderFunction = level_header

    def _exception_context(self, stream: str = "error") -> ExceptionContext:
        """create an `ExceptionContext` that writes to the handler of `stream`

        falls back to `sys.stderr` if that stream has no handler.

        # Raises:
         - `KeyError` : if `stream` is not a registered stream name or alias
        """
        import sys

        s: LoggingStream = self._streams[stream]
        handler = s.handler if s.handler is not None else sys.stderr
        return ExceptionContext(stream=handler)

    def log(
        self,
        msg: JSONitem = None,
        *,
        lvl: int | None = None,
        stream: str | None = None,
        console_print: bool = False,
        extra_indent: str = "",
        **kwargs: Any,
    ) -> None:
        """logging function

        ### Parameters:
         - `msg : JSONitem`
            message (usually string or dict) to be logged. a mapping is copied and
            its keys stored directly; anything else is stored under `_msg`
         - `lvl : int | None`
            level of message (lower levels are more important). if `None`, uses the
            stream's `default_level` when set, otherwise the logger's default level.
            the `None` stream always uses the logger's default level
            (defaults to `None`)
         - `stream : str | None`
            name of the stream to log to; unknown names are registered as new
            streams. `None` logs to the default stream, which always writes to the
            main log file
            (defaults to `None`)
         - `console_print : bool`
            override `console_print_threshold` setting
            (defaults to `False`)
         - `extra_indent : str`
            passed through to the level header function when printing to console
            (defaults to `""`)
         - `**kwargs`
            extra data, stored under the `_kwargs` key
        """
        # add to known stream names if not present
        if stream not in self._streams:
            self._streams[stream] = LoggingStream(stream)
        stream_obj: LoggingStream = self._streams[stream]

        # set default level to either global or stream-specific default level
        # ========================================
        if lvl is None:
            if stream is not None and stream_obj.default_level is not None:
                lvl = stream_obj.default_level
            else:
                lvl = self._default_level

        # print to console with formatting
        # ========================================
        _printed: bool = False
        if console_print or (lvl <= self._console_print_threshold):
            # add some formatting
            print(
                self._level_header(
                    msg=msg,
                    lvl=lvl,
                    stream=stream,
                    extra_indent=extra_indent,
                )
            )

            # store the last message time (unless tracking is disabled, i.e. None)
            if self._last_msg_time is not None:
                self._last_msg_time = time.time()

            _printed = True

        # convert and add data
        # ========================================
        # converting to dict
        msg_dict: dict[str, Any]
        if not isinstance(msg, typing.Mapping):
            msg_dict = {"_msg": msg}
        else:
            msg_dict = dict(typing.cast(typing.Mapping[str, Any], msg))

        # level metadata (stream metadata is handled by `LoggingStream`)
        msg_dict["_lvl"] = lvl

        # extra data in kwargs
        if len(kwargs) > 0:
            msg_dict["_kwargs"] = kwargs

        # add default contents (timing, etc); keys from the message take precedence
        msg_dict = {
            **{k: v() for k, v in stream_obj.default_contents.items()},
            **msg_dict,
        }

        # write
        # ========================================
        logfile_msg: str = json.dumps(json_serialize(msg_dict)) + "\n"
        s_handler: AnyIO | None = stream_obj.handler
        if stream is None or s_handler is None:
            # default stream, or a stream without its own handler: main log file
            self._log_file_handle.write(logfile_msg)
        else:
            # otherwise, write to the stream-specific file
            s_handler.write(logfile_msg)

        # if it was important enough to print, flush all streams
        if _printed:
            self.flush_all()

    def log_elapsed_last(
        self,
        lvl: int | None = None,
        stream: str | None = None,
        console_print: bool = True,
        **kwargs: Any,
    ) -> None:
        """logs the time elapsed since the last message was printed to the console (in any stream)

        the elapsed time in seconds (rounded to 6 decimals) is logged under the
        `elapsed_time` key. `lvl` defaults to the console print threshold.

        # Raises:
         - `ValueError` : if last-message-time tracking is disabled
           (`keep_last_msg_time=False`)
        """
        if self._last_msg_time is None:
            raise ValueError("no last message time!")
        self.log(
            {"elapsed_time": round(time.time() - self._last_msg_time, 6)},
            lvl=(lvl if lvl is not None else self._console_print_threshold),
            stream=stream,
            console_print=console_print,
            **kwargs,
        )

    def flush_all(self) -> None:
        """flush the main log file and every stream that has its own handler"""
        self._log_file_handle.flush()

        for stream in self._streams.values():
            if stream.handler is not None:
                stream.handler.flush()

    def __getattr__(self, stream: str) -> Callable[..., Any]:
        """`logger.some_stream(msg, ...)` logs to the stream `"some_stream"`

        only invoked for attributes not found normally. names starting with an
        underscore raise `AttributeError`, so private/dunder lookups fail normally
        instead of becoming stream loggers.
        """
        if stream.startswith("_"):
            raise AttributeError(f"invalid stream name {stream} (no underscores)")
        return partial(self.log, stream=stream)

    def __getitem__(self, stream: str) -> Callable[..., Any]:
        """`logger["some stream"](msg, ...)` logs to the named stream"""
        return partial(self.log, stream=stream)

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """calling the logger directly is the same as calling `log`"""
        self.log(*args, **kwargs)