From aec421b6a70bf63769dabf79b1b9c5d8bcc47e88 Mon Sep 17 00:00:00 2001 From: Gonzalo Mellizo-Soto Date: Thu, 6 Jun 2024 18:26:47 +0200 Subject: [PATCH 1/5] Add table printing and tests --- giza/cli/utils/output.py | 58 ++++++++++++++++++++++++++++++++++++++ tests/utils/test_output.py | 38 +++++++++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 giza/cli/utils/output.py create mode 100644 tests/utils/test_output.py diff --git a/giza/cli/utils/output.py b/giza/cli/utils/output.py new file mode 100644 index 0000000..f06eb59 --- /dev/null +++ b/giza/cli/utils/output.py @@ -0,0 +1,58 @@ +from typing import Union + +from pydantic import BaseModel, RootModel +from rich.console import Console +from rich.table import Table + + +def extract_row(model: BaseModel) -> list: + """ + Extracts the row from a model + + Args: + model (BaseModel): A pydantic model which we extreact the fields value and set to "" if None + + Returns: + list: A list with the values of the fields + """ + result = [] + + for field in model.model_fields.keys(): + value = getattr(model, field, "") + if value is None: + value = "" + result.append(str(value)) + + return result + + +def print_model(model: Union[BaseModel, RootModel], title=""): + """ + Utility function to print a model or a list of models in a table for pretty printing + + Args: + model (Union[BaseModel, RootModel]): The model or list of models to print + title (str, optional): Title of the table. Defaults to "". + """ + + table = Table(title=title) + console = Console() + + # If its a root model we need to iterate over the root list and add a row for each model + # RootModel goes first as it is a subclass of BaseModel + if isinstance(model, RootModel): + # We pick the first model to get the fields + try: + for field in model.root[0].model_fields.keys(): + table.add_column(field) + except IndexError: + return + for m in model.root: + table.add_row(*extract_row(m)) + # If its a single model we just create a table with the fields, one single row + elif isinstance(model, BaseModel): + for field in model.model_fields.keys(): + table.add_column(field) + table.add_row(*extract_row(model)) + + console.print(table) diff --git a/tests/utils/test_output.py b/tests/utils/test_output.py new file mode 100644 index 0000000..c2d25db --- /dev/null +++ b/tests/utils/test_output.py @@ -0,0 +1,38 @@ +from giza.cli.schemas.models import Model, ModelList +from giza.cli.utils.output import extract_row, print_model + +model_one = Model( + id=1, + name="Model One", + description="This is model one", +) + +model_two = Model( + id=2, + name="Model Two", +) + +models = ModelList(root=[model_one, model_two]) + + +def test_print_single_model(capsys): + print_model(model_one) + captured = capsys.readouterr() + assert "Model One" in captured.out + assert "This is model one" in captured.out + + +def test_print_list_model(capsys): + print_model(models) + captured = capsys.readouterr() + assert "Model One" in captured.out + assert "Model Two" in captured.out + assert "This is model one" in captured.out + assert "" in captured.out + + +def test_extract_row(): + row = extract_row(model_one) + assert row == ["1", "Model One", "This is model one"] + row = extract_row(model_two) + assert row == ["2", "Model Two", ""] From 39436a8caa91c5183c80338d6c638ca3b66f8df2 Mon Sep 17 00:00:00 2001 From: Gonzalo Mellizo-Soto Date: Thu, 6 Jun 2024 18:38:58 +0200 Subject: [PATCH 2/5] Move output to Echo class --- giza/cli/utils/echo.py | 56 +++++++++++++++++++++++++++++++++++- giza/cli/utils/output.py | 58 -------------------------------------- tests/utils/test_echo.py | 40 ++++++++++++++++++++++++++ tests/utils/test_output.py | 38 ------------------------- 4 files changed, 95 insertions(+), 97 deletions(-) delete mode 100644 giza/cli/utils/output.py delete mode 100644 tests/utils/test_output.py diff --git a/giza/cli/utils/echo.py b/giza/cli/utils/echo.py index 738c20d..964ea3a 100644 --- a/giza/cli/utils/echo.py +++ b/giza/cli/utils/echo.py @@ -1,9 +1,12 @@ import datetime as dt -from typing import Optional +from typing import Optional, Union import typer +from pydantic import BaseModel, RootModel from rich import print as rich_print from rich import reconfigure +from rich.console import Console +from rich.table import Table reconfigure(soft_wrap=True) @@ -144,3 +147,54 @@ def __call__(self, message: str) -> None: message (str): message to formatt and print """ self.info(message) + + def _extract_row(self, model: BaseModel) -> list: + """ + Extracts the row from a model + + Args: + model (BaseModel): A pydantic model which we extreact the fields value and set to "" if None + + Returns: + list: A list with the values of the fields + """ + result = [] + + for field in model.model_fields.keys(): + value = getattr(model, field, "") + if value is None: + value = "" + result.append(str(value)) + + return result + + def print_model(self, model: Union[BaseModel, RootModel], title=""): + """ + Utility function to print a model or a list of models in a table for pretty printing + + Args: + model (Union[BaseModel, RootModel]): The model or list of models to print + title (str, optional): Title of the table. Defaults to "". + """ + + table = Table(title=title) + console = Console() + + # If its a root model we need to iterate over the root list and add a row for each model + # RootModel goes first as it is a subclass of BaseModel + if isinstance(model, RootModel): + # We pick the first model to get the fields + try: + for field in model.root[0].model_fields.keys(): + table.add_column(field) + except IndexError: + return + for m in model.root: + table.add_row(*self._extract_row(m)) + # If its a single model we just create a table with the fields, one single row + elif isinstance(model, BaseModel): + for field in model.model_fields.keys(): + table.add_column(field) + table.add_row(*self._extract_row(model)) + + console.print(table) diff --git a/giza/cli/utils/output.py b/giza/cli/utils/output.py deleted file mode 100644 index f06eb59..0000000 --- a/giza/cli/utils/output.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Union - -from pydantic import BaseModel, RootModel -from rich.console import Console -from rich.table import Table - - -def extract_row(model: BaseModel) -> list: - """ - Extracts the row from a model - - Args: - model (BaseModel): A pydantic model which we extreact the fields value and set to "" if None - - Returns: - list: A list with the values of the fields - """ - result = [] - - for field in model.model_fields.keys(): - value = getattr(model, field, "") - if value is None: - value = "" - result.append(str(value)) - - return result - - -def print_model(model: Union[BaseModel, RootModel], title=""): - """ - Utility function to print a model or a list of models in a table for pretty printing - - Args: - model (Union[BaseModel, RootModel]): The model or list of models to print - title (str, optional): Title of the table. Defaults to "". - """ - - table = Table(title=title) - console = Console() - - # If its a root model we need to iterate over the root list and add a row for each model - # RootModel goes first as it is a subclass of BaseModel - if isinstance(model, RootModel): - # We pick the first model to get the fields - try: - for field in model.root[0].model_fields.keys(): - table.add_column(field) - except IndexError: - return - for m in model.root: - table.add_row(*extract_row(m)) - # If its a single model we just create a table with the fields, one single row - elif isinstance(model, BaseModel): - for field in model.model_fields.keys(): - table.add_column(field) - table.add_row(*extract_row(model)) - - console.print(table) diff --git a/tests/utils/test_echo.py b/tests/utils/test_echo.py index 1f2b624..6d3ba19 100644 --- a/tests/utils/test_echo.py +++ b/tests/utils/test_echo.py @@ -2,8 +2,48 @@ import pytest +from giza.cli.schemas.models import Model, ModelList from giza.cli.utils.echo import Echo +model_one = Model( + id=1, + name="Model One", + description="This is model one", +) + +model_two = Model( + id=2, + name="Model Two", +) + +models = ModelList(root=[model_one, model_two]) + + +def test_print_single_model(capsys): + echo = Echo() + echo.print_model(model_one) + captured = capsys.readouterr() + assert "Model One" in captured.out + assert "This is model one" in captured.out + + +def test_print_list_model(capsys): + echo = Echo() + echo.print_model(models) + captured = capsys.readouterr() + assert "Model One" in captured.out + assert "Model Two" in captured.out + assert "This is model one" in captured.out + assert "" in captured.out + + +def test_extract_row(): + echo = Echo() + row = echo._extract_row(model_one) + assert row == ["1", "Model One", "This is model one"] + row = echo._extract_row(model_two) + assert row == ["2", "Model Two", ""] + def test_format_message(): """ diff --git a/tests/utils/test_output.py b/tests/utils/test_output.py deleted file mode 100644 index c2d25db..0000000 --- a/tests/utils/test_output.py +++ /dev/null @@ -1,38 +0,0 @@ -from giza.cli.schemas.models import Model, ModelList -from giza.cli.utils.output import extract_row, print_model - -model_one = Model( - id=1, - name="Model One", - description="This is model one", -) - -model_two = Model( - id=2, - name="Model Two", -) - -models = ModelList(root=[model_one, model_two]) - - -def test_print_single_model(capsys): - print_model(model_one) - captured = capsys.readouterr() - assert "Model One" in captured.out - assert "This is model one" in captured.out - - -def test_print_list_model(capsys): - print_model(models) - captured = capsys.readouterr() - assert "Model One" in captured.out - assert "Model Two" in captured.out - assert "This is model one" in captured.out - assert "" in captured.out - - -def test_extract_row(): - row = extract_row(model_one) - assert row == ["1", "Model One", "This is model one"] - row = extract_row(model_two) - assert row == ["2", "Model Two", ""] From b1aa8021d06654cd88bacb153602ecc6dfcd04f1 Mon Sep 17 00:00:00 2001 From: Gonzalo Mellizo-Soto Date: Thu, 6 Jun 2024 19:24:23 +0200 Subject: [PATCH 3/5] Add model printing to the commands --- giza/cli/commands/agents.py | 9 ++++----- giza/cli/commands/endpoints.py | 13 ++++++------- giza/cli/commands/models.py | 7 +++---- giza/cli/commands/users.py | 3 +-- giza/cli/commands/versions.py | 7 +++---- giza/cli/commands/workspaces.py | 3 +-- giza/cli/frameworks/cairo.py | 3 +-- giza/cli/frameworks/ezkl.py | 3 +-- giza/cli/utils/echo.py | 2 +- tests/commands/test_agents.py | 2 +- tests/commands/test_endpoints.py | 11 +++++++---- 11 files changed, 29 insertions(+), 34 deletions(-) diff --git a/giza/cli/commands/agents.py b/giza/cli/commands/agents.py index 7dc7090..5f002f7 100644 --- a/giza/cli/commands/agents.py +++ b/giza/cli/commands/agents.py @@ -3,7 +3,6 @@ from typing import List, Optional import typer -from rich import print_json from rich.console import Console from rich.table import Table @@ -104,7 +103,7 @@ def create( }, ) agent = client.create(agent_create) - print_json(agent.model_dump_json()) + echo.print_model(agent) @app.command( @@ -135,7 +134,7 @@ def list( else: query_params = None agents: AgentList = client.list(params=query_params) - print_json(agents.model_dump_json()) + echo.print_model(agents) # giza/commands/deployments.py @@ -155,7 +154,7 @@ def get( with ExceptionHandler(debug=debug): client = AgentsClient(API_HOST) deployment = client.get(agent_id) - print_json(deployment.model_dump_json()) + echo.print_model(deployment) @app.command( @@ -206,4 +205,4 @@ def update( name=name, description=description, parameters=update_params ) agent = client.patch(agent_id, agent_update) - print_json(agent.model_dump_json()) + echo.print_model(agent) diff --git a/giza/cli/commands/endpoints.py b/giza/cli/commands/endpoints.py index c60fff1..6d03dcc 100644 --- a/giza/cli/commands/endpoints.py +++ b/giza/cli/commands/endpoints.py @@ -4,7 +4,6 @@ import typer from pydantic import ValidationError from requests import HTTPError -from rich import print_json from giza.cli import API_HOST from giza.cli.client import EndpointsClient @@ -113,7 +112,7 @@ def list( if debug: raise e sys.exit(1) - print_json(deployments.model_dump_json()) + echo.print_model(deployments) # giza/commands/deployments.py @@ -156,7 +155,7 @@ def get( if debug: raise e sys.exit(1) - print_json(deployment.model_dump_json()) + echo.print_model(deployment) @app.command( @@ -220,7 +219,7 @@ def list_proofs( if debug: raise e sys.exit(1) - print_json(proofs.model_dump_json(exclude_unset=True)) + echo.print_model(proofs) @app.command( @@ -266,7 +265,7 @@ def get_proof( if debug: raise e sys.exit(1) - print_json(proof.model_dump_json(exclude_unset=True)) + echo.print_model(proof) @app.command( @@ -340,7 +339,7 @@ def list_jobs( with ExceptionHandler(debug=debug): client = EndpointsClient(API_HOST) jobs = client.list_jobs(endpoint_id) - print_json(jobs.json(exclude_unset=True)) + echo.print_model(jobs) @app.command( @@ -363,7 +362,7 @@ def verify( with ExceptionHandler(debug=debug): client = EndpointsClient(API_HOST) verification = client.verify_proof(endpoint_id, proof_id) - print_json(verification.model_dump_json(exclude_unset=True)) + echo.print_model(verification) @app.command( diff --git a/giza/cli/commands/models.py b/giza/cli/commands/models.py index b2a957b..088c680 100644 --- a/giza/cli/commands/models.py +++ b/giza/cli/commands/models.py @@ -4,7 +4,6 @@ import typer from pydantic import ValidationError from requests import HTTPError -from rich import print_json from giza.cli import API_HOST from giza.cli.client import ModelsClient @@ -68,7 +67,7 @@ def get( if debug: raise e sys.exit(1) - print_json(model.model_dump_json()) + echo.print_model(model) @app.command( @@ -122,7 +121,7 @@ def list( if debug: raise e sys.exit(1) - print_json(models.model_dump_json()) + echo.print_model(models) @app.command( @@ -182,4 +181,4 @@ def create( if debug: raise e sys.exit(1) - print_json(model.model_dump_json()) + echo.print_model(model) diff --git a/giza/cli/commands/users.py b/giza/cli/commands/users.py index 4242383..7fb003a 100644 --- a/giza/cli/commands/users.py +++ b/giza/cli/commands/users.py @@ -5,7 +5,6 @@ from email_validator import EmailNotValidError, validate_email from pydantic import SecretStr, ValidationError from requests import HTTPError -from rich import print_json from giza.cli import API_HOST from giza.cli.client import UsersClient @@ -211,7 +210,7 @@ def me(debug: Optional[bool] = DEBUG_OPTION) -> None: client = UsersClient(API_HOST, debug=debug) user = client.me() - print_json(user.model_dump_json()) + echo.print_model(user) @app.command( diff --git a/giza/cli/commands/versions.py b/giza/cli/commands/versions.py index 09f3c04..6bb8d7b 100644 --- a/giza/cli/commands/versions.py +++ b/giza/cli/commands/versions.py @@ -6,7 +6,6 @@ from typing import Dict, Optional import typer -from rich import print_json from giza.cli import API_HOST from giza.cli.client import TranspileClient, VersionsClient @@ -59,7 +58,7 @@ def get( with ExceptionHandler(debug=debug): client = VersionsClient(API_HOST) version: Version = client.get(model_id, version_id) - print_json(version.model_dump_json(exclude={"logs"})) + echo.print_model(version) def transpile( @@ -172,7 +171,7 @@ def update( zip_path = zip_folder(model_path, tmp_dir) version = client.upload_cairo(model_id, version_id, zip_path) echo("Version updated ✅ ") - print_json(version.model_dump_json()) + echo.print_model(version) @app.command( @@ -193,7 +192,7 @@ def list( with ExceptionHandler(debug=debug): client = VersionsClient(API_HOST) versions: VersionList = client.list(model_id) - print_json(versions.model_dump_json()) + echo.print_model(versions) @app.command( diff --git a/giza/cli/commands/workspaces.py b/giza/cli/commands/workspaces.py index 458cf37..c9d97b1 100644 --- a/giza/cli/commands/workspaces.py +++ b/giza/cli/commands/workspaces.py @@ -5,7 +5,6 @@ import typer from pydantic import ValidationError from requests import HTTPError -from rich import print_json from rich.live import Live from giza.cli import API_HOST @@ -61,7 +60,7 @@ def get( echo.error("⛔️Please delete the workspace and create a new one⛔️") else: echo.info(f"✅ Workspace URL: {workspace.url} ✅") - print_json(workspace.model_dump_json()) + echo.print_model(workspace) @app.command( diff --git a/giza/cli/frameworks/cairo.py b/giza/cli/frameworks/cairo.py index 8aca437..48ece5e 100644 --- a/giza/cli/frameworks/cairo.py +++ b/giza/cli/frameworks/cairo.py @@ -8,7 +8,6 @@ import typer from pydantic import ValidationError from requests import HTTPError -from rich import print_json from rich.live import Live from rich.progress import Progress, SpinnerColumn, TextColumn from rich.spinner import Spinner @@ -100,7 +99,7 @@ def prove( proof_client = ProofsClient(API_HOST) proof: Proof = proof_client.get_by_job_id(current_job.id) echo("Proof metrics:") - print_json(json.dumps(proof.metrics)) + echo.print_model(proof) f.write(proof_client.download(proof.id)) echo(f"Proof saved at: {output_path}") except ValidationError as e: diff --git a/giza/cli/frameworks/ezkl.py b/giza/cli/frameworks/ezkl.py index a8f462a..7a79e48 100644 --- a/giza/cli/frameworks/ezkl.py +++ b/giza/cli/frameworks/ezkl.py @@ -6,7 +6,6 @@ from pydantic import ValidationError from requests import HTTPError -from rich import print_json from rich.live import Live from rich.progress import Progress, Spinner, SpinnerColumn, TextColumn @@ -208,7 +207,7 @@ def prove( proof: Proof = proof_client.get_by_job_id(current_job.id) echo(f"Proof created with id -> {proof.id} ✅") echo("Proof metrics:") - print_json(json.dumps(proof.metrics)) + echo.print_model(proof) f.write(proof_client.download(proof.id)) echo(f"Proof saved at: {output_path}") except ValidationError as e: diff --git a/giza/cli/utils/echo.py b/giza/cli/utils/echo.py index 964ea3a..ccae417 100644 --- a/giza/cli/utils/echo.py +++ b/giza/cli/utils/echo.py @@ -186,7 +186,7 @@ def print_model(self, model: Union[BaseModel, RootModel], title=""): # We pick the first model to get the fields try: for field in model.root[0].model_fields.keys(): - table.add_column(field) + table.add_column(field, overflow="fold") except IndexError: return for m in model.root: diff --git a/tests/commands/test_agents.py b/tests/commands/test_agents.py index 6004cd3..e0213e7 100644 --- a/tests/commands/test_agents.py +++ b/tests/commands/test_agents.py @@ -124,7 +124,7 @@ def test_create_agent_with_endpoint_id(): mock_endpoints.assert_called_once() assert result.exit_code == 0 assert "Using endpoint id to create agent" in result.output - assert "test agent endpoint" in result.output + assert "test agent" in result.output def test_create_agent_no_ids(): diff --git a/tests/commands/test_endpoints.py b/tests/commands/test_endpoints.py index 2175c06..154004c 100644 --- a/tests/commands/test_endpoints.py +++ b/tests/commands/test_endpoints.py @@ -141,8 +141,8 @@ def test_list_deployments(): ) mock_list.assert_called_once() assert result.exit_code == 0 - assert "giza-deployment-1" in result.stdout - assert "giza-deployment-2" in result.stdout + assert "giza-" in result.stdout + assert "2" in result.stdout def test_create_deployments_empty(): @@ -206,7 +206,9 @@ def test_get_deployment(): ) mock_deployment.assert_called_once() assert result.exit_code == 0 - assert "giza-deployment-1" in result.stdout + assert "giza-" in result.stdout + assert "size" in result.stdout + assert "S" in result.stdout def test_get_deployment_http_error(): @@ -243,4 +245,5 @@ def test_endpoints_verify(): ) mock_verify.assert_called_once() assert result.exit_code == 0 - assert ' "verification": true' in result.stdout + assert "verification" in result.stdout + assert "True" in result.stdout From 271ad138a8447ac8072125e3b75aab84c3747c38 Mon Sep 17 00:00:00 2001 From: Gonzalo Mellizo-Soto Date: Thu, 6 Jun 2024 20:16:10 +0200 Subject: [PATCH 4/5] Add basics for json output and add it to users --- giza/cli/commands/users.py | 10 ++++++++-- giza/cli/options.py | 6 ++++++ giza/cli/utils/echo.py | 40 ++++++++++++++++++++++++++++++++++---- tests/utils/test_echo.py | 2 +- 4 files changed, 51 insertions(+), 7 deletions(-) diff --git a/giza/cli/commands/users.py b/giza/cli/commands/users.py index 7fb003a..1ac2bfb 100644 --- a/giza/cli/commands/users.py +++ b/giza/cli/commands/users.py @@ -9,7 +9,7 @@ from giza.cli import API_HOST from giza.cli.client import UsersClient from giza.cli.exceptions import PasswordError -from giza.cli.options import DEBUG_OPTION +from giza.cli.options import DEBUG_OPTION, JSON_OPTION from giza.cli.schemas import users from giza.cli.utils import echo, get_response_info from giza.cli.utils.misc import _check_password_strength @@ -199,13 +199,19 @@ def create_api_key( Verification and an active token is needed. """, ) -def me(debug: Optional[bool] = DEBUG_OPTION) -> None: +def me( + debug: Optional[bool] = DEBUG_OPTION, json: Optional[bool] = JSON_OPTION +) -> None: """ Retrieve information about the current user and print it as json to stdout. Args: debug (Optional[bool], optional): Whether to add debug information, will show requests, extra logs and traceback if there is an Exception. Defaults to DEBUG_OPTION (False) """ + + if json: + echo.set_log_file() + echo("Retrieving information about me!") client = UsersClient(API_HOST, debug=debug) user = client.me() diff --git a/giza/cli/options.py b/giza/cli/options.py index c5ec5ad..1f329bf 100644 --- a/giza/cli/options.py +++ b/giza/cli/options.py @@ -40,3 +40,9 @@ help="The input data to use", ) NAME_OPTION = typer.Option(None, "--name", "-n", help="The name of the resource") +JSON_OPTION = typer.Option( + None, + "--json", + "-j", + help="Whether to print the output as JSON. This will make that the only ouput is the json and the logs will be saved to `giza.log`", +) diff --git a/giza/cli/utils/echo.py b/giza/cli/utils/echo.py index ccae417..0ab8f29 100644 --- a/giza/cli/utils/echo.py +++ b/giza/cli/utils/echo.py @@ -1,10 +1,12 @@ +import atexit import datetime as dt +from io import TextIOWrapper from typing import Optional, Union import typer from pydantic import BaseModel, RootModel from rich import print as rich_print -from rich import reconfigure +from rich import print_json, reconfigure from rich.console import Console from rich.table import Table @@ -18,8 +20,32 @@ class Echo: Provides utilities to print different levels of the messages and provides formatting capabilities to each of the levels. """ - def __init__(self, debug: Optional[bool] = False) -> None: + LOG_FILE: str = "giza.log" + + def __init__( + self, debug: Optional[bool] = False, output_json: bool = False + ) -> None: self._debug = debug + self._json = output_json + self._file: TextIOWrapper | None = None + + if self._json: + self.set_log_file() + + def set_log_file(self) -> None: + """ + Set the log file to use for the echo class, use manually when needed + """ + self._json = True + self._file = open(self.LOG_FILE, "w") + atexit.register(self._close) + + def _close(self) -> None: + """ + Close the file if it was opened + """ + if self._json and self._file is not None: + self._file.close() def format_message( self, message: str, field: str = "giza", color: str = "orange3" @@ -91,12 +117,15 @@ def echo(self, message: str, formatted: str) -> None: formatted (str): formatted message """ try: - rich_print(formatted) + rich_print(formatted, file=self._file) except (UnicodeDecodeError, UnicodeEncodeError, UnicodeError): # fallback to the standard print behaviour formatted_time = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] formatted_message = f"[giza][{formatted_time}] {message}" - typer.echo(formatted_message) + if self._json and self._file is not None: + self._file.write(formatted_message + "\n") + else: + typer.echo(formatted_message) def error(self, message: str) -> None: """ @@ -176,6 +205,9 @@ def print_model(self, model: Union[BaseModel, RootModel], title=""): model (Union[BaseModel, RootModel]): The model or list of models to print title (str, optional): Title of the table. Defaults to "". """ + if self._json: + print_json(model.model_dump_json()) + return table = Table(title=title) console = Console() diff --git a/tests/utils/test_echo.py b/tests/utils/test_echo.py index 6d3ba19..1952f57 100644 --- a/tests/utils/test_echo.py +++ b/tests/utils/test_echo.py @@ -97,7 +97,7 @@ def test_echo_ok(): formatted = f"[red]{message}[/red]" with patch("giza.cli.utils.echo.rich_print") as print_mock: echo.echo(message, formatted) - print_mock.assert_called_once_with(formatted) + print_mock.assert_called_once_with(formatted, file=None) def test_echo_ok_fallback(capsys): From ec42795f8273cab2a52315fbd0c776bbb5347346 Mon Sep 17 00:00:00 2001 From: Gonzalo Mellizo-Soto Date: Mon, 10 Jun 2024 18:24:32 +0200 Subject: [PATCH 5/5] Add json option for all resources --- giza/cli/commands/agents.py | 16 ++++++++++++++++ giza/cli/commands/endpoints.py | 19 +++++++++++++++++++ giza/cli/commands/models.py | 12 ++++++++++-- giza/cli/commands/versions.py | 13 +++++++++++++ giza/cli/frameworks/cairo.py | 5 ++++- giza/cli/options.py | 2 +- giza/cli/utils/echo.py | 5 +++-- 7 files changed, 66 insertions(+), 6 deletions(-) diff --git a/giza/cli/commands/agents.py b/giza/cli/commands/agents.py index 5f002f7..da7a4c6 100644 --- a/giza/cli/commands/agents.py +++ b/giza/cli/commands/agents.py @@ -13,6 +13,7 @@ DEBUG_OPTION, DESCRIPTION_OPTION, ENDPOINT_OPTION, + JSON_OPTION, MODEL_OPTION, NAME_OPTION, VERSION_OPTION, @@ -43,8 +44,13 @@ def create( endpoint_id: int = ENDPOINT_OPTION, name: Optional[str] = NAME_OPTION, description: Optional[str] = DESCRIPTION_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + + if json: + echo.set_log_file() + echo("Creating agent ✅ ") if not model_id and not version_id and not endpoint_id: @@ -122,8 +128,12 @@ def list( parameters: Optional[List[str]] = typer.Option( None, "--parameters", "-p", help="The parameters of the agent" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() + echo("Listing agents ✅ ") with ExceptionHandler(debug=debug): client = AgentsClient(API_HOST) @@ -148,8 +158,11 @@ def list( ) def get( agent_id: int = AGENT_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting agent {agent_id} ✅ ") with ExceptionHandler(debug=debug): client = AgentsClient(API_HOST) @@ -195,8 +208,11 @@ def update( parameters: Optional[List[str]] = typer.Option( None, "--parameters", "-p", help="The parameters of the agent" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Updating agent {agent_id} ✅ ") with ExceptionHandler(debug=debug): client = AgentsClient(API_HOST) diff --git a/giza/cli/commands/endpoints.py b/giza/cli/commands/endpoints.py index 6d03dcc..fc5797b 100644 --- a/giza/cli/commands/endpoints.py +++ b/giza/cli/commands/endpoints.py @@ -12,6 +12,7 @@ DEBUG_OPTION, ENDPOINT_OPTION, FRAMEWORK_OPTION, + JSON_OPTION, MODEL_OPTION, VERSION_OPTION, ) @@ -76,8 +77,11 @@ def list( only_active: bool = typer.Option( False, "--only-active", "-a", help="Only list active endpoints" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo("Listing endpoints ✅ ") params = {} try: @@ -126,8 +130,11 @@ def list( ) def get( endpoint_id: int = ENDPOINT_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting endpoint {endpoint_id} ✅ ") try: client = EndpointsClient(API_HOST) @@ -190,8 +197,11 @@ def delete_endpoint( ) def list_proofs( endpoint_id: int = ENDPOINT_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting proofs from endpoint {endpoint_id} ✅ ") try: client = EndpointsClient(API_HOST) @@ -236,8 +246,11 @@ def get_proof( proof_id: str = typer.Option( None, "--proof-id", "-p", help="The ID or request id of the proof" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting proof from endpoint {endpoint_id} ✅ ") try: client = EndpointsClient(API_HOST) @@ -333,8 +346,11 @@ def download_proof( ) def list_jobs( endpoint_id: int = ENDPOINT_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting jobs from endpoint {endpoint_id} ✅ ") with ExceptionHandler(debug=debug): client = EndpointsClient(API_HOST) @@ -356,8 +372,11 @@ def verify( proof_id: str = typer.Option( None, "--proof-id", "-p", help="The ID or request id of the proof" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Verifying proof from endpoint {endpoint_id} ✅ ") with ExceptionHandler(debug=debug): client = EndpointsClient(API_HOST) diff --git a/giza/cli/commands/models.py b/giza/cli/commands/models.py index 088c680..90d622f 100644 --- a/giza/cli/commands/models.py +++ b/giza/cli/commands/models.py @@ -7,7 +7,7 @@ from giza.cli import API_HOST from giza.cli.client import ModelsClient -from giza.cli.options import DEBUG_OPTION, DESCRIPTION_OPTION, MODEL_OPTION +from giza.cli.options import DEBUG_OPTION, DESCRIPTION_OPTION, JSON_OPTION, MODEL_OPTION from giza.cli.schemas.models import ModelCreate from giza.cli.utils import echo, get_response_info @@ -27,6 +27,7 @@ ) def get( model_id: int = MODEL_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: """ @@ -40,6 +41,8 @@ def get( ValidationError: input fields are validated, if these are not suitable the exception is raised HTTPError: request error to the API, 4XX or 5XX """ + if json: + echo.set_log_file() echo("Retrieving model information ✅ ") try: client = ModelsClient(API_HOST) @@ -81,6 +84,7 @@ def get( """, ) def list( + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: """ @@ -93,7 +97,8 @@ def list( ValidationError: input fields are validated, if these are not suitable the exception is raised HTTPError: request error to the API, 4XX or 5XX """ - + if json: + echo.set_log_file() echo("Listing models ✅ ") try: client = ModelsClient(API_HOST) @@ -138,6 +143,7 @@ def create( ..., "--name", "-n", help="Name of the model to be created" ), description: str = DESCRIPTION_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: """ @@ -151,6 +157,8 @@ def create( ValidationError: input fields are validated, if these are not suitable the exception is raised HTTPError: request error to the API, 4XX or 5XX """ + if json: + echo.set_log_file() if name is None or name == "": echo.error("Name is required") sys.exit(1) diff --git a/giza/cli/commands/versions.py b/giza/cli/commands/versions.py index 6bb8d7b..cbee757 100644 --- a/giza/cli/commands/versions.py +++ b/giza/cli/commands/versions.py @@ -15,6 +15,7 @@ DESCRIPTION_OPTION, FRAMEWORK_OPTION, INPUT_OPTION, + JSON_OPTION, MODEL_OPTION, OUTPUT_PATH_OPTION, VERSION_OPTION, @@ -49,8 +50,11 @@ def update_sierra(model_id: int, version_id: int, model_path: str): def get( model_id: int = MODEL_OPTION, version_id: int = VERSION_OPTION, + json: Optional[bool] = JSON_OPTION, debug: bool = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() if any([model_id is None, version_id is None]): echo.error("⛔️Model ID and version ID are required⛔️") sys.exit(1) @@ -79,6 +83,7 @@ def transpile( "--download-sierra", help="Download the siera file is the modle is fully compatible. CAIRO only.", ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: if framework == Framework.CAIRO: @@ -90,6 +95,7 @@ def transpile( output_path=output_path, download_model=download_model, download_sierra=download_sierra, + json=json, debug=debug, ) elif framework == Framework.EZKL: @@ -145,8 +151,12 @@ def update( model_path: str = typer.Option( None, "--model-path", "-M", help="Path of the model to update" ), + json: Optional[bool] = JSON_OPTION, debug: bool = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() + if any([model_id is None, version_id is None]): echo.error("⛔️Model ID and version ID are required to update the version⛔️") sys.exit(1) @@ -183,8 +193,11 @@ def update( ) def list( model_id: int = MODEL_OPTION, + json: Optional[bool] = JSON_OPTION, debug: bool = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() if model_id is None: echo.error("⛔️Model ID is required⛔️") sys.exit(1) diff --git a/giza/cli/frameworks/cairo.py b/giza/cli/frameworks/cairo.py index 48ece5e..fb764c8 100644 --- a/giza/cli/frameworks/cairo.py +++ b/giza/cli/frameworks/cairo.py @@ -228,6 +228,7 @@ def transpile( output_path: str, download_model: bool, download_sierra: bool, + json: Optional[bool], debug: Optional[bool], ) -> None: """ @@ -256,7 +257,7 @@ def transpile( ValidationError: If there is a validation error with the model or version. HTTPError: If there is an HTTP error while communicating with the server. """ - echo = Echo(debug=debug) + echo = Echo(debug=debug, output_json=json) if model_path is None: echo.error("No model name provided, please provide a model path ⛔️") sys.exit(1) @@ -400,6 +401,8 @@ def transpile( if debug: raise zip_error sys.exit(1) + echo.print_model(model, title="Model") + echo.print_model(version, title="Version") def verify( diff --git a/giza/cli/options.py b/giza/cli/options.py index 1f329bf..817ea98 100644 --- a/giza/cli/options.py +++ b/giza/cli/options.py @@ -41,7 +41,7 @@ ) NAME_OPTION = typer.Option(None, "--name", "-n", help="The name of the resource") JSON_OPTION = typer.Option( - None, + False, "--json", "-j", help="Whether to print the output as JSON. This will make that the only ouput is the json and the logs will be saved to `giza.log`", diff --git a/giza/cli/utils/echo.py b/giza/cli/utils/echo.py index 0ab8f29..2115630 100644 --- a/giza/cli/utils/echo.py +++ b/giza/cli/utils/echo.py @@ -23,7 +23,7 @@ class Echo: LOG_FILE: str = "giza.log" def __init__( - self, debug: Optional[bool] = False, output_json: bool = False + self, debug: Optional[bool] = False, output_json: bool | None = False ) -> None: self._debug = debug self._json = output_json @@ -205,8 +205,9 @@ def print_model(self, model: Union[BaseModel, RootModel], title=""): model (Union[BaseModel, RootModel]): The model or list of models to print title (str, optional): Title of the table. Defaults to "". """ - if self._json: + if self._json and self._file is not None: print_json(model.model_dump_json()) + self._file.write(model.model_dump_json(indent=4)) return table = Table(title=title)