Coverage for muutils/logger/logger.py: 74%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"""logger with streams & levels, and a timer context manager 

2 

3- `SimpleLogger` is an extremely simple logger that can write to both console and a file 

4- `Logger` class handles levels in a slightly different way than default python `logging`, 

5 and also has "streams" which allow for different sorts of output in the same logger 

6 this was mostly made with training models in mind and storing both metadata and loss 

7- `TimerContext` is a context manager that can be used to time the duration of a block of code 

8""" 

9 

10from __future__ import annotations 

11 

12import json 

13import time 

14import typing 

15from functools import partial 

16from typing import Callable, Sequence 

17 

18from muutils.json_serialize import JSONitem, json_serialize 

19from muutils.logger.exception_context import ExceptionContext 

20from muutils.logger.headerfuncs import HEADER_FUNCTIONS, HeaderFunction 

21from muutils.logger.loggingstream import LoggingStream 

22from muutils.logger.simplelogger import AnyIO, SimpleLogger 

23 

24# pylint: disable=arguments-differ, bad-indentation, trailing-whitespace, trailing-newlines, unnecessary-pass, consider-using-with, use-dict-literal 

25 

26 

def decode_level(level: int) -> str:
    """Return a human-readable label for a numeric log level.

    Non-negative levels are informational, levels in `[-255, -1]` are
    warnings, and anything below `-255` is a fatal error. The numeric
    level is always included in the returned string.

    # Parameters:
     - `level : int`
       numeric log level (lower is more severe)

    # Returns:
     - `str` : one of `"INFO(n)"`, `"WARNING(n)"`, `"FATAL_ERROR(n)"`

    # Raises:
     - `TypeError` : if `level` is not an `int`
    """
    if not isinstance(level, int):
        raise TypeError(f"level must be int, got {type(level) = } {level = }")

    # guard clauses from least to most severe
    if level >= 0:
        return f"INFO({level})"
    if level >= -255:
        return f"WARNING({level})"
    return f"FATAL_ERROR({level})"

37 

38 

39# todo: add a context which catches and logs all exceptions 

class Logger(SimpleLogger):
    """logger with more features, including log levels and streams

    # Parameters:
     - `log_path : str | None`
       default log file path
       (defaults to `None`)
     - `log_file : AnyIO | None`
       default log io, should have a `.write()` method (pass only this or `log_path`, not both)
       (defaults to `None`)
     - `default_level : int`
       default log level for streams/messages that don't specify a level
       (defaults to `0`)
     - `console_print_threshold : int`
       log level at which to print to the console, anything greater will not be printed unless overridden by `console_print`
       (defaults to `50`)
     - `level_header : HeaderFunction`
       function for formatting log messages when printing to console
       (defaults to `HEADER_FUNCTIONS["md"]`)
     - `streams : dict[str | None, LoggingStream] | Sequence[LoggingStream]`
       streams to register, in addition to a default `"error"` stream
       (defaults to `()`)
     - `keep_last_msg_time : bool`
       whether to track the time of the last console-printed message
       (defaults to `True`)
     - `timestamp : bool`
       whether to add timestamps to every log message (under the `_timestamp` key);
       must be `True`
       (defaults to `True`)

    # Raises:
     - `ValueError` : on unrecognized kwargs, `timestamp=False`, or duplicate
       stream aliases
    """

    def __init__(
        self,
        log_path: str | None = None,
        log_file: AnyIO | None = None,
        default_level: int = 0,
        console_print_threshold: int = 50,
        level_header: HeaderFunction = HEADER_FUNCTIONS["md"],
        streams: dict[str | None, LoggingStream] | Sequence[LoggingStream] = (),
        keep_last_msg_time: bool = True,
        # junk args
        timestamp: bool = True,
        **kwargs,
    ):
        # junk arg checking
        # ==================================================
        if len(kwargs) > 0:
            raise ValueError(f"unrecognized kwargs: {kwargs}")

        if not timestamp:
            raise ValueError(
                "timestamp must be True -- why would you not want timestamps?"
            )

        # timing
        # ==================================================
        # timing compares
        self._keep_last_msg_time: bool = keep_last_msg_time
        # TODO: handle per stream?
        # fix: previously `_last_msg_time` was unconditionally set, so the
        # `keep_last_msg_time` flag had no effect; now disabling it leaves
        # the timestamp as `None` (and `log_elapsed_last` will raise)
        self._last_msg_time: float | None = (
            time.time() if keep_last_msg_time else None
        )

        # basic setup
        # ==================================================
        # init BaseLogger
        super().__init__(log_file=log_file, log_path=log_path, timestamp=timestamp)

        # level-related
        self._console_print_threshold: int = console_print_threshold
        self._default_level: int = default_level

        # set up streams: normalize a sequence of streams into a name -> stream map
        self._streams: dict[str | None, LoggingStream] = (
            streams
            if isinstance(streams, typing.Mapping)
            else {s.name: s for s in streams}
        )
        # default error stream
        if "error" not in self._streams:
            self._streams["error"] = LoggingStream(
                "error",
                aliases={
                    "err",
                    "except",
                    "Exception",
                    "exception",
                    "exceptions",
                    "errors",
                },
            )

        # check alias duplicates across all streams
        alias_set: set[str | None] = set()
        for stream in self._streams.values():
            for alias in stream.aliases:
                if alias in alias_set:
                    raise ValueError(f"alias {alias} is already in use")
                alias_set.add(alias)

        # add aliases as extra keys pointing at the same stream
        # (snapshot values() first, since we mutate the dict while iterating)
        for stream in tuple(self._streams.values()):
            for alias in stream.aliases:
                if alias not in self._streams:
                    self._streams[alias] = stream

        # print formatting
        self._level_header: HeaderFunction = level_header

        # fix: removed leftover debug `print(...)` of the streams dict that
        # polluted stdout on every Logger construction

    def _exception_context(
        self,
        stream: str = "error",
        # level: int = -256,
        # **kwargs,
    ) -> ExceptionContext:
        """return a context manager that routes exceptions to `stream`"""
        s: LoggingStream = self._streams[stream]
        return ExceptionContext(stream=s)

    def log(  # type: ignore # yes, the signatures are different here.
        self,
        msg: JSONitem = None,
        lvl: int | None = None,
        stream: str | None = None,
        console_print: bool = False,
        extra_indent: str = "",
        **kwargs,
    ):
        """logging function

        ### Parameters:
         - `msg : JSONitem`
           message (usually string or dict) to be logged
         - `lvl : int | None`
           level of message (lower levels are more important)
           (defaults to `None`)
         - `stream : str | None`
           whether to log to a stream (defaults to `None`), which logs to the default `None` stream
           (defaults to `None`)
         - `console_print : bool`
           override `console_print_threshold` setting
           (defaults to `False`)
         - `extra_indent : str`
           prepended indentation passed through to the header formatter
           (defaults to `""`)
        """

        # add to known stream names if not present
        if stream not in self._streams:
            self._streams[stream] = LoggingStream(stream)

        # set default level to either global or stream-specific default level
        # ========================================
        if lvl is None:
            if stream is None:
                lvl = self._default_level
            else:
                stream_default: int | None = self._streams[stream].default_level
                lvl = (
                    stream_default
                    if stream_default is not None
                    else self._default_level
                )

        assert lvl is not None, "lvl should not be None at this point"

        # print to console with formatting
        # ========================================
        _printed: bool = False
        if console_print or (lvl <= self._console_print_threshold):
            # add some formatting
            print(
                self._level_header(
                    msg=msg,
                    lvl=lvl,
                    stream=stream,
                    extra_indent=extra_indent,
                )
            )

            # store the last message time (only if timing is enabled)
            if self._last_msg_time is not None:
                self._last_msg_time = time.time()

            _printed = True

        # convert and add data
        # ========================================
        # converting to dict
        # fix: copy caller-supplied mappings instead of mutating them in place
        # (the old code wrote `_lvl`/`_kwargs` straight into the caller's dict,
        # and crashed on immutable mappings)
        msg_dict: dict
        if isinstance(msg, typing.Mapping):
            msg_dict = dict(msg)
        else:
            msg_dict = {"_msg": msg}

        # level metadata (lvl is guaranteed non-None by the assert above)
        msg_dict["_lvl"] = lvl

        # msg_dict["_stream"] = stream # moved to LoggingStream

        # extra data in kwargs
        if len(kwargs) > 0:
            msg_dict["_kwargs"] = kwargs

        # add default contents (timing, etc); explicit keys in msg_dict win
        msg_dict = {
            **{k: v() for k, v in self._streams[stream].default_contents.items()},
            **msg_dict,
        }

        # write
        # ========================================
        logfile_msg: str = json.dumps(json_serialize(msg_dict)) + "\n"
        s_handler: AnyIO | None = (
            self._streams[stream].handler if stream in self._streams else None
        )
        if (stream is None) or (s_handler is None):
            # write to the main log file if no stream handler is available
            self._log_file_handle.write(logfile_msg)
        else:
            # otherwise, write to the stream-specific file
            s_handler.write(logfile_msg)

        # if it was important enough to print, flush all streams
        if _printed:
            self.flush_all()

    def log_elapsed_last(
        self,
        lvl: int | None = None,
        stream: str | None = None,
        console_print: bool = True,
        **kwargs,
    ) -> float:
        """logs the time elapsed since the last message was printed to the console (in any stream)

        returns the elapsed time in seconds
        (fix: previously returned `None` despite the `-> float` annotation,
        since `Logger.log` has no return value)

        # Raises:
         - `ValueError` : if no last message time is recorded
        """
        if self._last_msg_time is None:
            raise ValueError("no last message time!")
        elapsed: float = round(time.time() - self._last_msg_time, 6)
        self.log(
            {"elapsed_time": elapsed},
            lvl=(lvl if lvl is not None else self._console_print_threshold),
            stream=stream,
            console_print=console_print,
            **kwargs,
        )
        return elapsed

    def flush_all(self):
        """flush the main log file and every stream handler"""

        self._log_file_handle.flush()

        for stream in self._streams.values():
            if stream.handler is not None:
                stream.handler.flush()

    def __getattr__(self, stream: str) -> Callable:
        """`logger.my_stream(msg)` logs `msg` to the stream `"my_stream"`"""
        # reject dunder/private lookups so pickling & introspection behave
        if stream.startswith("_"):
            raise AttributeError(f"invalid stream name {stream} (no underscores)")
        return partial(self.log, stream=stream)

    def __getitem__(self, stream: str) -> Callable:
        """`logger["my_stream"](msg)` logs `msg` to the stream `"my_stream"`"""
        return partial(self.log, stream=stream)

    def __call__(self, *args, **kwargs):
        """calling the logger directly is equivalent to calling `.log()`"""
        return self.log(*args, **kwargs)