diff --git a/README.md b/README.md index 09bd209..ebf1865 100644 --- a/README.md +++ b/README.md @@ -62,13 +62,15 @@ docker exec -it mast-api python -m src.api.create /code/data/metadata/mini ### Running Unit Tests Verify everything is setup correctly by running the unit tests. -To run the unit tests you may use `pytest` like so: +To run the unit tests, input the following command inside your environment: ```bash -python -m pytest tests +pytest -rsx tests/ --data-path="INSERT FULL PATH TO DATA HERE" ``` -This will run some unit tests for the REST and GraphQL APIs against the data in the database. +The data path will be will be along the lines of `~/fair-mast/data/metadata/mini`. + +This will run some unit tests for the REST and GraphQL APIs against a testing database, created from the data in `--data-path`. ### Uploading Data to the Minio Storage diff --git a/src/api/create.py b/src/api/create.py index d8250bb..ac53c22 100644 --- a/src/api/create.py +++ b/src/api/create.py @@ -86,6 +86,7 @@ def create_database(self): SQLModel.metadata.create_all(engine) # recreate the engine/metadata object self.metadata_obj, self.engine = connect(self.uri) + return engine def create_cpf_summary(self, data_path: Path): """Create the CPF summary table""" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/archive/conftest.py b/tests/archive/conftest.py index 2503ac0..462af3c 100644 --- a/tests/archive/conftest.py +++ b/tests/archive/conftest.py @@ -2,7 +2,6 @@ import numpy as np import xarray as xr - @pytest.fixture def fake_dataset(): return xr.Dataset( diff --git a/tests/archive/test_reader.py b/tests/archive/test_reader.py index f2ec7e6..3527737 100644 --- a/tests/archive/test_reader.py +++ b/tests/archive/test_reader.py @@ -1,7 +1,8 @@ import xarray as xr +import pytest +pyuda_import = pytest.importorskip("pyuda") from src.archive.reader import DatasetReader - def test_list_signals(): shot = 30420 reader = DatasetReader(shot) diff --git a/tests/archive/test_task.py b/tests/archive/test_task.py index 5dcc7ca..696b9d3 100644 --- a/tests/archive/test_task.py +++ b/tests/archive/test_task.py @@ -1,4 +1,5 @@ import pytest +pyuda_import = pytest.importorskip("pyuda") import subprocess import zarr import xarray as xr diff --git a/tests/archive/test_writer.py b/tests/archive/test_writer.py index 6d668d6..bf9c014 100644 --- a/tests/archive/test_writer.py +++ b/tests/archive/test_writer.py @@ -1,4 +1,5 @@ import pytest +pyuda_import = pytest.importorskip("pyuda") import zarr import xarray as xr import numpy as np diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8febbfd --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,70 @@ +import pytest +from fastapi.testclient import TestClient +from sqlalchemy.orm import sessionmaker +from src.api.main import app +from src.api.database import get_db +import os +from sqlmodel import SQLModel, create_engine +from pathlib import Path +from sqlalchemy_utils.functions import ( + drop_database, + database_exists, + create_database, +) +from src.api.create import DBCreationClient +from os.path import exists +import sys + +# Fixture to get data path from command line +def pytest_addoption(parser): + parser.addoption( + "--data-path", + action="store", + default="~/data/metadata/mini", + help="Path to mini data directory", + ) + +@pytest.fixture(scope="session") +def data_path(request): + return request.config.getoption("--data-path") + +# Set up the database URL +host = os.environ.get("DATABASE_HOST", "localhost") +SQLALCHEMY_DATABASE_TEST_URL = f"postgresql://root:root@{host}:5432/test_db" + +# Fixture to create and drop the database +@pytest.fixture(scope="session") +def test_db(data_path): + data_path = Path(data_path) + client = DBCreationClient(SQLALCHEMY_DATABASE_TEST_URL) + engine = client.create_database() + client.create_cpf_summary(data_path) + client.create_scenarios(data_path) + client.create_shots(data_path) + client.create_signals(data_path) + client.create_sources(data_path) + client.create_shot_source_links(data_path) + + TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + yield TestingSessionLocal() + + drop_database(SQLALCHEMY_DATABASE_TEST_URL) + +# Fixture to override the database dependency +@pytest.fixture +def override_get_db(test_db): + def override(): + try: + db = test_db + yield db + finally: + db.close() + + app.dependency_overrides[get_db] = override + +# Fixture to create a client for testing +@pytest.fixture(scope="module") +def client(): + with TestClient(app) as client: + yield client \ No newline at end of file diff --git a/tests/test_archive.py b/tests/test_archive.py index 08dae77..52e98f6 100644 --- a/tests/test_archive.py +++ b/tests/test_archive.py @@ -1,4 +1,6 @@ from pathlib import Path +import pytest +pyuda_import = pytest.importorskip("pyuda") from src.archive.main import _do_write_signal, DatasetReader, DatasetWriter, read_config, get_file_system def test_write_diagnostic_signal(benchmark): diff --git a/tests/test_json.py b/tests/test_json.py index 0811804..bdebcde 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,41 +1,35 @@ -import pytest -import pandas as pd -from fastapi.testclient import TestClient -from src.api.main import app, get_db, add_pagination - - -@pytest.fixture(scope="module") -def client(): - get_db() - client = TestClient(app) - # Need to re-add pagination after creating the client - add_pagination(app) - return client +# ========= Tests ========== +def test_get_cpf(client, override_get_db): + response = client.get("/json/cpf_summary") + assert response.status_code == 200 + data = response.json() + assert len(data['items']) == 50 + assert "description" in data['items'][0] -def test_get_shots(client): +def test_get_shots(client, override_get_db): response = client.get("json/shots") data = response.json() assert response.status_code == 200 assert len(data['items']) == 50 - assert data['previous_page'] == None + assert data['previous_page'] is None -def test_get_shots_filter_shot_id(client): +def test_get_shots_filter_shot_id(client, override_get_db): response = client.get("json/shots?filters=shot_id$geq:30000") data = response.json() assert response.status_code == 200 assert len(data['items']) == 50 -def test_get_shot(client): +def test_get_shot(client, override_get_db): response = client.get("json/shots/30420") data = response.json() assert response.status_code == 200 assert data["shot_id"] == 30420 -def test_get_shot_aggregate(client): +def test_get_shot_aggregate(client, override_get_db): response = client.get( "json/shots/aggregate?data=shot_id$min:,shot_id$max:&groupby=campaign&sort=-min_shot_id" ) @@ -45,22 +39,21 @@ def test_get_shot_aggregate(client): assert data[0]["campaign"] == "M9" -def test_get_signals_aggregate(client): +def test_get_signals_aggregate(client, override_get_db): response = client.get("json/signals/aggregate?data=shot_id$count:&groupby=quality") data = response.json() assert response.status_code == 200 assert len(data) == 1 - -def test_get_signals_for_shot(client): +def test_get_signals_for_shot(client, override_get_db): response = client.get("json/shots/30471/signals") data = response.json() assert response.status_code == 200 assert len(data['items']) == 50 - assert data['previous_page'] == None + assert data['previous_page'] is None -def test_get_signals(client): +def test_get_signals(client, override_get_db): response = client.get("json/signals") data = response.json() assert response.status_code == 200 @@ -69,27 +62,27 @@ def test_get_signals(client): assert len(data['items']) == 50 -def test_get_cpf_summary(client): +def test_get_cpf_summary(client, override_get_db): response = client.get("json/cpf_summary") data = response.json() assert response.status_code == 200 assert len(data['items']) == 50 -def test_get_scenarios(client): +def test_get_scenarios(client, override_get_db): response = client.get("json/scenarios") data = response.json() assert response.status_code == 200 assert len(data['items']) == 34 -def test_get_sources(client): +def test_get_sources(client, override_get_db): response = client.get("json/sources") data = response.json() assert response.status_code == 200 assert len(data['items']) == 50 -def test_get_cursor(client): +def test_get_cursor(client, override_get_db): response = client.get("json/signals") first_page_data = response.json() next_cursor = first_page_data['next_page'] @@ -97,7 +90,8 @@ def test_get_cursor(client): next_page_data = next_response.json() assert next_page_data['current_page'] == next_cursor -def test_cursor_response(client): +def test_cursor_response(client, override_get_db): response = client.get("json/signals") data = response.json() - assert data['previous_page'] == None \ No newline at end of file + assert data['previous_page'] is None +