diff --git a/python/e2b_code_interpreter/messaging.py b/python/e2b_code_interpreter/messaging.py index 5f5c732..9db096c 100644 --- a/python/e2b_code_interpreter/messaging.py +++ b/python/e2b_code_interpreter/messaging.py @@ -14,8 +14,7 @@ from e2b.utils.future import DeferredFuture from pydantic import ConfigDict, PrivateAttr, BaseModel -from e2b_code_interpreter.models import Cell, Error - +from e2b_code_interpreter.models import Cell, DisplayData, Error logger = logging.getLogger(__name__) @@ -170,7 +169,7 @@ def _receive_message(self, data: dict): result.error = Error( name=data["content"]["ename"], value=data["content"]["evalue"], - traceback=data["content"]["traceback"], + traceback_raw=data["content"]["traceback"], ) elif data["msg_type"] == "stream": @@ -196,9 +195,9 @@ def _receive_message(self, data: dict): ) elif data["msg_type"] in "display_data": - result.display_data.append(data["content"]["data"]) + result.display_data.append(DisplayData(**data["content"]["data"])) elif data["msg_type"] == "execute_result": - result.result = data["content"]["data"] + result.result = DisplayData(**data["content"]["data"]) elif data["msg_type"] == "status": if data["content"]["execution_state"] == "idle": if cell.input_accepted: @@ -210,7 +209,7 @@ def _receive_message(self, data: dict): result.error = Error( name=data["content"]["ename"], value=data["content"]["evalue"], - traceback=data["content"]["traceback"], + traceback_raw=data["content"]["traceback"], ) cell.result.set_result(result) @@ -220,7 +219,7 @@ def _receive_message(self, data: dict): result.error = Error( name=data["content"]["ename"], value=data["content"]["evalue"], - traceback=data["content"]["traceback"], + traceback_raw=data["content"]["traceback"], ) elif data["content"]["status"] == "ok": pass diff --git a/python/e2b_code_interpreter/models.py b/python/e2b_code_interpreter/models.py index 4b565ec..27cc686 100644 --- a/python/e2b_code_interpreter/models.py +++ b/python/e2b_code_interpreter/models.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import List, Optional from pydantic import BaseModel @@ -13,19 +13,110 @@ class Error(BaseModel): name: str value: str - traceback: List[str] + traceback_raw: List[str] + @property + def traceback(self) -> str: + """ + Returns the traceback as a single string. + + :return: The traceback as a single string. + """ + return "\n".join(self.traceback_raw) -MIMEType = str -DisplayData = Dict[MIMEType, str] -""" -Represents the data to be displayed as a result of executing a cell in a Jupyter notebook. -Dictionary that maps MIME types to their corresponding string representations of the data. -MIME types are used to specify the nature and format of the data, allowing for the representation -of various types of content such as text, images, and more. Each key in the interface is a MIME type -string, and its value is the data associated with that MIME type, formatted as a string. -""" +class DisplayData(dict): + """ + Represents the data to be displayed as a result of executing a cell in a Jupyter notebook. + + Dictionary that maps MIME types to their corresponding string representations of the data. + MIME types are used to specify the nature and format of the data, allowing for the representation + of various types of content such as text, images, and more. Each key in the interface is a MIME type + string, and its value is the data associated with that MIME type, formatted as a string. + """ + + def __init__(self, *args, **kwargs: str): + super().__init__(*args, **kwargs) + + def __str__(self): + """ + Returns the text representation of the data. + + :return: The text representation of the data. + """ + return self["text/plain"] + + def _repr_html_(self): + """ + Returns the HTML representation of the data. + + :return: The HTML representation of the data. + """ + return self.get("text/html", None) + + def _repr_markdown_(self): + """ + Returns the Markdown representation of the data. + + :return: The Markdown representation of the data. + """ + return self.get("text/markdown", None) + + def _repr_svg_(self): + """ + Returns the SVG representation of the data. + + :return: The SVG representation of the data. + """ + return self.get("image/svg+xml", None) + + def _repr_png_(self): + """ + Returns the PNG representation of the data. + + :return: The PNG representation of the data. + """ + return self.get("image/png", None) + + def _repr_jpeg_(self): + """ + Returns the JPEG representation of the data. + + :return: The JPEG representation of the data. + """ + return self.get("image/jpeg", None) + + def _repr_pdf_(self): + """ + Returns the PDF representation of the data. + + :return: The PDF representation of the data. + """ + return self.get("application/pdf", None) + + def _repr_latex_(self): + """ + Returns the LaTeX representation of the data. + + :return: The LaTeX representation of the data. + """ + return self.get("text/latex", None) + + def _repr_json_(self): + """ + Returns the JSON representation of the data. + + :return: The JSON representation of the data. + """ + return self.get("application/json", None) + + def _repr_javascript_(self): + """ + Returns the JavaScript representation of the data. + + :return: The JavaScript representation of the data. + """ + return self.get("application/javascript", None) class Cell(BaseModel): @@ -39,6 +130,9 @@ class Cell(BaseModel): error: an Error object if an error occurred, None otherwise. """ + class Config: + arbitrary_types_allowed = True + result: DisplayData = {} display_data: List[DisplayData] = [] stdout: List[str] = [] @@ -52,7 +146,7 @@ def text(self) -> str: :return: The text representation of the result. """ - return self.result.get("text/plain", None) + return self.result["text/plain"] class KernelException(Exception): diff --git a/python/example.py b/python/example.py index 9dec4b7..f1674c5 100644 --- a/python/example.py +++ b/python/example.py @@ -1,5 +1,3 @@ -import logging - from e2b_code_interpreter.main import CodeInterpreter from dotenv import load_dotenv