Skip to content

Commit

Permalink
Merge pull request #30 from ukaea/james/test_db
Browse files Browse the repository at this point in the history
James/test db
  • Loading branch information
samueljackson92 authored May 16, 2024
2 parents b1a6be3 + 6533b5e commit f35e13c
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 34 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ docker exec -it mast-api python -m src.api.create /code/data/metadata/mini
### Running Unit Tests
Verify everything is setup correctly by running the unit tests.

To run the unit tests you may use `pytest` like so:
To run the unit tests, input the following command inside your environment:

```bash
python -m pytest tests
pytest -rsx tests/ --data-path="INSERT FULL PATH TO DATA HERE"
```

This will run some unit tests for the REST and GraphQL APIs against the data in the database.
The data path will be will be along the lines of `~/fair-mast/data/metadata/mini`.

This will run some unit tests for the REST and GraphQL APIs against a testing database, created from the data in `--data-path`.

### Uploading Data to the Minio Storage

Expand Down
1 change: 1 addition & 0 deletions src/api/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def create_database(self):
SQLModel.metadata.create_all(engine)
# recreate the engine/metadata object
self.metadata_obj, self.engine = connect(self.uri)
return engine

def create_cpf_summary(self, data_path: Path):
"""Create the CPF summary table"""
Expand Down
Empty file added tests/__init__.py
Empty file.
1 change: 0 additions & 1 deletion tests/archive/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
import xarray as xr


@pytest.fixture
def fake_dataset():
return xr.Dataset(
Expand Down
3 changes: 2 additions & 1 deletion tests/archive/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import xarray as xr
import pytest
pyuda_import = pytest.importorskip("pyuda")
from src.archive.reader import DatasetReader


def test_list_signals():
shot = 30420
reader = DatasetReader(shot)
Expand Down
1 change: 1 addition & 0 deletions tests/archive/test_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
pyuda_import = pytest.importorskip("pyuda")
import subprocess
import zarr
import xarray as xr
Expand Down
1 change: 1 addition & 0 deletions tests/archive/test_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
pyuda_import = pytest.importorskip("pyuda")
import zarr
import xarray as xr
import numpy as np
Expand Down
70 changes: 70 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import sessionmaker
from src.api.main import app
from src.api.database import get_db
import os
from sqlmodel import SQLModel, create_engine
from pathlib import Path
from sqlalchemy_utils.functions import (
drop_database,
database_exists,
create_database,
)
from src.api.create import DBCreationClient
from os.path import exists
import sys

# Fixture to get data path from command line
def pytest_addoption(parser):
parser.addoption(
"--data-path",
action="store",
default="~/data/metadata/mini",
help="Path to mini data directory",
)

@pytest.fixture(scope="session")
def data_path(request):
return request.config.getoption("--data-path")

# Set up the database URL
host = os.environ.get("DATABASE_HOST", "localhost")
SQLALCHEMY_DATABASE_TEST_URL = f"postgresql://root:root@{host}:5432/test_db"

# Fixture to create and drop the database
@pytest.fixture(scope="session")
def test_db(data_path):
data_path = Path(data_path)
client = DBCreationClient(SQLALCHEMY_DATABASE_TEST_URL)
engine = client.create_database()
client.create_cpf_summary(data_path)
client.create_scenarios(data_path)
client.create_shots(data_path)
client.create_signals(data_path)
client.create_sources(data_path)
client.create_shot_source_links(data_path)

TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

yield TestingSessionLocal()

drop_database(SQLALCHEMY_DATABASE_TEST_URL)

# Fixture to override the database dependency
@pytest.fixture
def override_get_db(test_db):
def override():
try:
db = test_db
yield db
finally:
db.close()

app.dependency_overrides[get_db] = override

# Fixture to create a client for testing
@pytest.fixture(scope="module")
def client():
with TestClient(app) as client:
yield client
2 changes: 2 additions & 0 deletions tests/test_archive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from pathlib import Path
import pytest
pyuda_import = pytest.importorskip("pyuda")
from src.archive.main import _do_write_signal, DatasetReader, DatasetWriter, read_config, get_file_system

def test_write_diagnostic_signal(benchmark):
Expand Down
52 changes: 23 additions & 29 deletions tests/test_json.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,35 @@
import pytest
import pandas as pd
from fastapi.testclient import TestClient
from src.api.main import app, get_db, add_pagination


@pytest.fixture(scope="module")
def client():
get_db()
client = TestClient(app)
# Need to re-add pagination after creating the client
add_pagination(app)
return client
# ========= Tests ==========

def test_get_cpf(client, override_get_db):
response = client.get("/json/cpf_summary")
assert response.status_code == 200
data = response.json()
assert len(data['items']) == 50
assert "description" in data['items'][0]

def test_get_shots(client):
def test_get_shots(client, override_get_db):
response = client.get("json/shots")
data = response.json()
assert response.status_code == 200
assert len(data['items']) == 50
assert data['previous_page'] == None
assert data['previous_page'] is None


def test_get_shots_filter_shot_id(client):
def test_get_shots_filter_shot_id(client, override_get_db):
response = client.get("json/shots?filters=shot_id$geq:30000")
data = response.json()
assert response.status_code == 200
assert len(data['items']) == 50


def test_get_shot(client):
def test_get_shot(client, override_get_db):
response = client.get("json/shots/30420")
data = response.json()
assert response.status_code == 200
assert data["shot_id"] == 30420


def test_get_shot_aggregate(client):
def test_get_shot_aggregate(client, override_get_db):
response = client.get(
"json/shots/aggregate?data=shot_id$min:,shot_id$max:&groupby=campaign&sort=-min_shot_id"
)
Expand All @@ -45,22 +39,21 @@ def test_get_shot_aggregate(client):
assert data[0]["campaign"] == "M9"


def test_get_signals_aggregate(client):
def test_get_signals_aggregate(client, override_get_db):
response = client.get("json/signals/aggregate?data=shot_id$count:&groupby=quality")
data = response.json()
assert response.status_code == 200
assert len(data) == 1


def test_get_signals_for_shot(client):
def test_get_signals_for_shot(client, override_get_db):
response = client.get("json/shots/30471/signals")
data = response.json()
assert response.status_code == 200
assert len(data['items']) == 50
assert data['previous_page'] == None
assert data['previous_page'] is None


def test_get_signals(client):
def test_get_signals(client, override_get_db):
response = client.get("json/signals")
data = response.json()
assert response.status_code == 200
Expand All @@ -69,35 +62,36 @@ def test_get_signals(client):
assert len(data['items']) == 50


def test_get_cpf_summary(client):
def test_get_cpf_summary(client, override_get_db):
response = client.get("json/cpf_summary")
data = response.json()
assert response.status_code == 200
assert len(data['items']) == 50


def test_get_scenarios(client):
def test_get_scenarios(client, override_get_db):
response = client.get("json/scenarios")
data = response.json()
assert response.status_code == 200
assert len(data['items']) == 34


def test_get_sources(client):
def test_get_sources(client, override_get_db):
response = client.get("json/sources")
data = response.json()
assert response.status_code == 200
assert len(data['items']) == 50

def test_get_cursor(client):
def test_get_cursor(client, override_get_db):
response = client.get("json/signals")
first_page_data = response.json()
next_cursor = first_page_data['next_page']
next_response = client.get(f"json/signals?cursor={next_cursor}")
next_page_data = next_response.json()
assert next_page_data['current_page'] == next_cursor

def test_cursor_response(client):
def test_cursor_response(client, override_get_db):
response = client.get("json/signals")
data = response.json()
assert data['previous_page'] == None
assert data['previous_page'] is None

0 comments on commit f35e13c

Please sign in to comment.