Skip to content

Commit

Permalink
feat(file_logger): attempt to JSON-serialize objects
Browse files Browse the repository at this point in the history
  • Loading branch information
serhez committed May 16, 2024
1 parent c8198be commit c6f3c86
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 13 deletions.
22 changes: 11 additions & 11 deletions mloggers/file_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from datetime import datetime
from typing import Any

import numpy as np
import numpy.typing as npt
from termcolor import colored

from mloggers._log_levels import LogLevel
from mloggers.logger import Logger
from mloggers.utils import serialize


class FileLogger(Logger):
Expand Down Expand Up @@ -75,16 +75,16 @@ def log(
else:
message = messages[0]

# Convert numpy's ndarrays to lists so that they are JSON serializable
if isinstance(message, np.ndarray):
message = message.tolist()
if isinstance(message, dict):
for key, value in message.items():
if isinstance(value, np.ndarray):
message[key] = value.tolist()
elif hasattr(message, "__str__") and callable(getattr(message, "__str__")):
message = str(message)
# JSON-serialize the message
try:
message = serialize(message)
except TypeError as e:
print(
f'{colored("[ERROR]", "red")} [FileLogger] Could not convert the message to a JSON serializable format: {e}'
)
return

# Read the existing logs
try:
with open(self._file_path, "r") as file:
existing_content = file.read()
Expand All @@ -109,8 +109,8 @@ def log(
)
return

# Create the new log and write it to the file
new_logs = prev_logs.copy()

try:
log: dict[str, Any] = {
"timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
Expand Down
45 changes: 45 additions & 0 deletions mloggers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import jsonpickle
import numpy as np


def serialize(message: object) -> object:
"""
Serializes the message to a JSON serializable format.
### Parameters
----------
`message`: The message to serialize.
### Returns
----------
The serialized message.
### Raises
----------
- `TypeError`: If the message cannot be serialized.
"""

if isinstance(message, np.ndarray):
return message.tolist()
elif isinstance(message, dict):
for key, value in message.items():
message[key] = serialize(value)
return message
elif hasattr(message, "toJSON"):
return message.toJSON() # type:ignore[reportAttributeAccessIssue]
elif hasattr(message, "to_json"):
return message.to_json() # type:ignore[reportAttributeAccessIssue]
elif hasattr(message, "to_dict"):
return message.to_dict() # type:ignore[reportAttributeAccessIssue]
elif hasattr(message, "__str__") and callable(getattr(message, "__str__")):
return str(message)
else:
try:
return dict(message) # type:ignore[reportCallIssue, reportArgumentType]
except Exception:
try:
return jsonpickle.encode(message)
except Exception as e:
raise TypeError(
f"Could not serialize the message: {message}. Error: {e}"
)
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "mloggers"
version = "1.3.1"
version = "1.3.2"
authors = [
{ name = "Sergio Hernandez Gutierrez", email = "[email protected]" },
{ name = "Matteo Merler", email = "[email protected]" },
Expand All @@ -17,7 +17,15 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = ["aenum", "numpy", "omegaconf", "termcolor", "wandb", "rich"]
dependencies = [
"aenum",
"jsonpickle",
"numpy",
"omegaconf",
"termcolor",
"wandb",
"rich",
]

[project.urls]
Homepage = "https://github.com/serhez/mloggers"
Expand Down

0 comments on commit c6f3c86

Please sign in to comment.