diff --git a/mloggers/file_logger.py b/mloggers/file_logger.py index 0b1276a..c3ad964 100644 --- a/mloggers/file_logger.py +++ b/mloggers/file_logger.py @@ -3,12 +3,12 @@ from datetime import datetime from typing import Any -import numpy as np import numpy.typing as npt from termcolor import colored from mloggers._log_levels import LogLevel from mloggers.logger import Logger +from mloggers.utils import serialize class FileLogger(Logger): @@ -75,16 +75,16 @@ def log( else: message = messages[0] - # Convert numpy's ndarrays to lists so that they are JSON serializable - if isinstance(message, np.ndarray): - message = message.tolist() - if isinstance(message, dict): - for key, value in message.items(): - if isinstance(value, np.ndarray): - message[key] = value.tolist() - elif hasattr(message, "__str__") and callable(getattr(message, "__str__")): - message = str(message) + # JSON-serialize the message + try: + message = serialize(message) + except TypeError as e: + print( + f'{colored("[ERROR]", "red")} [FileLogger] Could not convert the message to a JSON serializable format: {e}' + ) + return + # Read the existing logs try: with open(self._file_path, "r") as file: existing_content = file.read() @@ -109,8 +109,8 @@ def log( ) return + # Create the new log and write it to the file new_logs = prev_logs.copy() - try: log: dict[str, Any] = { "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"), diff --git a/mloggers/utils.py b/mloggers/utils.py new file mode 100644 index 0000000..a48c008 --- /dev/null +++ b/mloggers/utils.py @@ -0,0 +1,45 @@ +import jsonpickle +import numpy as np + + +def serialize(message: object) -> object: + """ + Serializes the message to a JSON serializable format. + + ### Parameters + ---------- + `message`: The message to serialize. + + ### Returns + ---------- + The serialized message. + + ### Raises + ---------- + - `TypeError`: If the message cannot be serialized. + """ + + if isinstance(message, np.ndarray): + return message.tolist() + elif isinstance(message, dict): + for key, value in message.items(): + message[key] = serialize(value) + return message + elif hasattr(message, "toJSON"): + return message.toJSON() # type:ignore[reportAttributeAccessIssue] + elif hasattr(message, "to_json"): + return message.to_json() # type:ignore[reportAttributeAccessIssue] + elif hasattr(message, "to_dict"): + return message.to_dict() # type:ignore[reportAttributeAccessIssue] + elif hasattr(message, "__str__") and callable(getattr(message, "__str__")): + return str(message) + else: + try: + return dict(message) # type:ignore[reportCallIssue, reportArgumentType] + except Exception: + try: + return jsonpickle.encode(message) + except Exception as e: + raise TypeError( + f"Could not serialize the message: {message}. Error: {e}" + ) diff --git a/pyproject.toml b/pyproject.toml index 018fd18..14eb271 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mloggers" -version = "1.3.1" +version = "1.3.2" authors = [ { name = "Sergio Hernandez Gutierrez", email = "contact.sergiohernandez@gmail.com" }, { name = "Matteo Merler", email = "matteo.merler@gmail.com" }, @@ -17,7 +17,15 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = ["aenum", "numpy", "omegaconf", "termcolor", "wandb", "rich"] +dependencies = [ + "aenum", + "jsonpickle", + "numpy", + "omegaconf", + "termcolor", + "wandb", + "rich", +] [project.urls] Homepage = "https://github.com/serhez/mloggers"