Skip to content

Commit

Permalink
Merge pull request #54 from ukaea/samueljackson92/faster-index-loading
Browse files (browse the repository at this point in the history)
Add faster loading through parquet
  • Loading branch information
samueljackson92 authored Jun 27, 2024
2 parents 83fc615 + d0dcc59 commit 0859724
Show file tree
Hide file tree
Showing 177 changed files with 2,578 additions and 7,924 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
- name: Run tests
run: |
source venv/bin/activate
python -m pytest -rsx tests/ --data-path=/home/runner/work/fair-mast/fair-mast/tests/mock_data/mini
python -m pytest -rsx tests/ --data-path=/home/runner/work/fair-mast/fair-mast/tests/mock_data/index
ruff-code-check:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions dev/docker/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ services:
restart: always
volumes:
- ../../tests/mock_data:/code/data
- ../../data/index:/code/index
- ../../src:/code/src
ports:
- '8081:5000'
Expand Down
2 changes: 1 addition & 1 deletion docs/_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ logo: assets/MAST_plasma_image.jpg

# Information about where the book exists on the web
repository:
url: https://github.com/samueljackson92/mast-book # Online location of your book
url: https://github.com/ukaea/fair-mast/ # Online location of your book
branch: main

exclude_patterns: [data/*, 'data']
Expand Down
7 changes: 3 additions & 4 deletions docs/config.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# host: https://mastapp.site
host: http://localhost:8081
rest_api: http://localhost:8081/json
host: https://mastapp.site
rest_api: http://mastapp.site/json
graphql_api: https://mastapp.site/graphql
s3_api: https://s3.echo.stfc.ac.uk
s3_api: https://s3.echo.stfc.ac.uk
186 changes: 53 additions & 133 deletions src/api/create.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import math
import numpy as np
from enum import Enum
from pathlib import Path
import pandas as pd
import dask
import click
import uuid
import pyarrow.parquet as pq
from tqdm import tqdm
from sqlalchemy_utils.functions import (
drop_database,
Expand Down Expand Up @@ -107,14 +109,14 @@ def create_user(self):

def create_cpf_summary(self, data_path: Path):
"""Create the CPF summary table"""
paths = data_path.glob("*_cpf_columns.parquet")
paths = data_path.glob("cpf/*_cpf_columns.parquet")
for path in paths:
df = pd.read_parquet(path)
df.to_sql("cpf_summary", self.uri, if_exists="replace")

def create_scenarios(self, data_path: Path):
"""Create the scenarios metadata table"""
shot_file_name = data_path.parent / "shot_metadata.parquet"
shot_file_name = data_path / "shots.parquet"
shot_metadata = pd.read_parquet(shot_file_name)
ids = shot_metadata["scenario_id"].unique()
scenarios = shot_metadata["scenario"].unique()
Expand All @@ -125,21 +127,26 @@ def create_scenarios(self, data_path: Path):

def create_shots(self, data_path: Path):
"""Create the shot metadata table"""
shot_file_name = data_path.parent / "shot_metadata.parquet"
shot_metadata = pd.read_parquet(shot_file_name)
sources_file = data_path / "sources.parquet"
sources_metadata = pd.read_parquet(sources_file)
shot_ids = sources_metadata.shot_id.unique()

shot_file_name = data_path / "shots.parquet"
shot_metadata = pd.read_parquet(shot_file_name)
shot_metadata = shot_metadata.loc[shot_metadata["shot_id"] <= LAST_MAST_SHOT]
shot_metadata["facility"] = "MAST"
shot_metadata = shot_metadata.loc[shot_metadata.shot_id.isin(shot_ids)]
shot_metadata = shot_metadata.set_index("shot_id", drop=True)
shot_metadata = shot_metadata.sort_index()

shot_metadata["scenario"] = shot_metadata["scenario_id"]
shot_metadata["facility"] = "MAST"
shot_metadata = shot_metadata.drop(["scenario_id", "reference_id"], axis=1)
shot_metadata["uuid"] = shot_metadata.index.map(get_dataset_uuid)
shot_metadata["url"] = (
"s3://mast/level1/shots/" + shot_metadata.index.astype(str) + ".zarr"
)

paths = data_path.glob("*_cpf_data.parquet")
paths = data_path.glob("cpf/*_cpf_data.parquet")
cpfs = []
for path in paths:
cpf_metadata = read_cpf_metadata(path)
Expand All @@ -148,153 +155,69 @@ def create_shots(self, data_path: Path):
cpfs.append(cpf_metadata)

cpfs = pd.concat(cpfs, axis=0)
cpfs = cpfs = cpfs.reset_index()
cpfs = cpfs.loc[cpfs.shot_id <= LAST_MAST_SHOT]
cpfs = cpfs.drop_duplicates(subset="shot_id")
cpfs = cpfs.set_index("shot_id")

shot_metadata = pd.merge(
shot_metadata,
cpfs,
left_on="shot_id",
right_on="shot_id",
how="inner",
how="left",
)

shot_metadata.to_sql("shots", self.uri, if_exists="append")

def create_signal_datasets(self, file_name: str, url_type: URLType = URLType.S3):
"""Create the signal metadata table"""
signal_dataset_metadata = pd.read_parquet(file_name)
signal_dataset_metadata = signal_dataset_metadata.loc[
~signal_dataset_metadata.uri.str.contains("mini")
]
signal_dataset_metadata = signal_dataset_metadata.loc[
~signal_dataset_metadata["type"].isna()
]

signal_dataset_metadata["name"] = signal_dataset_metadata["name"].map(
normalize_signal_name
)
signal_dataset_metadata["quality"] = signal_dataset_metadata["status"].map(
lookup_status_code
)

signal_dataset_metadata["dimensions"] = signal_dataset_metadata[
"dimensions"
].map(list)
signal_dataset_metadata["doi"] = ""

signal_dataset_metadata["url"] = signal_dataset_metadata["name"].map(
lambda name: f"s3://mast/{name}.zarr"
)

signal_dataset_metadata["signal_type"] = signal_dataset_metadata["type"]
signal_dataset_metadata["csd3_path"] = signal_dataset_metadata["uri"]

signal_metadata = signal_dataset_metadata[
[
# "context_",
"uuid",
"name",
"description",
"signal_type",
"quality",
"dimensions",
"rank",
"units",
"doi",
"url",
"csd3_path",
]
]
signal_metadata.to_sql(
"signal_datasets", self.uri, if_exists="append", index=False
)

def create_signals(self, data_path: Path):
logging.info(f"Loading signals from {data_path}/signals")
file_names = data_path.glob("signals/**/*.parquet")
file_names = list(file_names)
logging.info(f"Loading signals from {data_path}")
file_name = data_path / "signals.parquet"

parquet_file = pq.ParquetFile(file_name)
batch_size = 10000
n = math.ceil(parquet_file.scan_contents() / batch_size)
for batch in tqdm(parquet_file.iter_batches(batch_size=batch_size), total=n):
signals_metadata = batch.to_pandas()

for file_name in tqdm(file_names):
signals_metadata = pd.read_parquet(file_name)
signals_metadata = signals_metadata.rename(
columns=dict(shot_nums="shot_id")
)

if len(signals_metadata) == 0 or "shot_id" not in signals_metadata.columns:
continue

df = signals_metadata
df = df[df.shot_id <= LAST_MAST_SHOT].copy()
df = df.rename({"dataset_item_uuid": "uuid"}, axis=1)
df["uuid"] = [
get_dataset_item_uuid(item["name"], item["shot_id"])
for key, item in df.iterrows()
]
df = df[df.shot_id <= LAST_MAST_SHOT]
df = df.drop_duplicates(subset="uuid")

df["quality"] = df["status"].map(lookup_status_code)

df["shape"] = df["shape"].map(
lambda x: x.tolist() if x is not None else None
)
df["shape"] = df["shape"].map(lambda x: x.tolist())
df["dimensions"] = df["dimensions"].map(lambda x: x.tolist())

df["url"] = (
"s3://mast/shots/M9/" + df["shot_id"].map(str) + ".zarr/" + df["group"]
"s3://mast/level1/shots/"
+ df["shot_id"].map(str)
+ ".zarr/"
+ df["name"]
)

df["version"] = 0
df["signal_type"] = df["type"]

if "IMAGE_SUBCLASS" not in df:
df["IMAGE_SUBCLASS"] = None

df["subclass"] = df["IMAGE_SUBCLASS"]

if "format" not in df:
df["format"] = None

if "units" not in df:
df["units"] = ""

uda_attributes = ["uda_name", "mds_name", "file_name", "format"]
df = df.drop(uda_attributes, axis=1)
df["shot_id"] = df.shot_id.astype(int)
columns = [
"uuid",
"shot_id",
"quality",
"shape",
"name",
"url",
"version",
"units",
"signal_type",
"description",
"subclass",
"format",
]
df = df[columns]
df = df.set_index("shot_id")
df = df.set_index("shot_id", drop=True)
df["description"] = df.description.map(lambda x: "" if x is None else x)
df.to_sql("signals", self.uri, if_exists="append")

def create_sources(self, data_path: Path):
source_metadata = pd.read_parquet(data_path.parent / "sources_metadata.parquet")
source_metadata["name"] = source_metadata["source_alias"]
source_metadata["source_type"] = source_metadata["type"]
source_metadata = source_metadata[["description", "name", "source_type"]]
source_metadata = source_metadata.drop_duplicates()
source_metadata = source_metadata.sort_values("name")
source_metadata.to_sql("sources", self.uri, if_exists="append", index=False)

def create_shot_source_links(self, data_path: Path):
sources_metadata = pd.read_parquet(
data_path.parent / "sources_metadata.parquet"
)
sources_metadata["source"] = sources_metadata["source_alias"]
sources_metadata["quality"] = sources_metadata["status"].map(lookup_status_code)
sources_metadata["shot_id"] = sources_metadata["shot"].astype(int)
sources_metadata = sources_metadata[
["source", "shot_id", "quality", "pass", "format"]
]
sources_metadata = sources_metadata.sort_values("source")
sources_metadata.to_sql(
"shot_source_link", self.uri, if_exists="append", index=False
source_metadata = pd.read_parquet(data_path / "sources.parquet")
source_metadata = source_metadata.drop_duplicates("uuid")
source_metadata = source_metadata.loc[source_metadata.shot_id <= LAST_MAST_SHOT]
source_metadata["url"] = (
"s3://mast/level1/shots/"
+ source_metadata["shot_id"].map(str)
+ ".zarr/"
+ source_metadata["name"]
)
column_names = ["uuid", "shot_id", "name", "description", "quality", "url"]
source_metadata = source_metadata[column_names]
source_metadata.to_sql("sources", self.uri, if_exists="append", index=False)


def read_cpf_metadata(cpf_file_name: Path) -> pd.DataFrame:
Expand All @@ -320,22 +243,19 @@ def create_db_and_tables(data_path):

# populate the database tables
logging.info("Create CPF summary")
client.create_cpf_summary(data_path)
client.create_cpf_summary(data_path / "cpf")

logging.info("Create Scenarios")
client.create_scenarios(data_path)

logging.info("Create Shots")
client.create_shots(data_path)

logging.info("Create Signals")
client.create_signals(data_path)

logging.info("Create Sources")
client.create_sources(data_path)

logging.info("Create Shot Source Links")
client.create_shot_source_links(data_path)
logging.info("Create Signals")
client.create_signals(data_path)

client.create_user()

Expand Down
18 changes: 8 additions & 10 deletions src/api/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,21 +260,19 @@ def on_request_end(self):
)
class Shot:
@strawberry.field
def signal_datasets(
self,
limit: Optional[int] = None,
where: Optional[ShotWhereFilter] = None,
) -> List[strawberry.LazyType["Shot", __module__]]: # noqa: F821
results = do_where_child_member(self.signal_datasets, where)
def signals(
self, limit: Optional[int] = None, where: Optional[SignalWhereFilter] = None
) -> List[strawberry.LazyType["Signal", __module__]]: # noqa: F821
results = do_where_child_member(self.signals, where)
if limit is not None:
results = results[:limit]
return results

@strawberry.field
def signals(
self, limit: Optional[int] = None, where: Optional[SignalWhereFilter] = None
) -> List[strawberry.LazyType["Signal", __module__]]: # noqa: F821
results = do_where_child_member(self.signals, where)
def sources(
self, limit: Optional[int] = None, where: Optional[SourceWhereFilter] = None
) -> List[strawberry.LazyType["Source", __module__]]: # noqa: F821
results = do_where_child_member(self.sources, where)
if limit is not None:
results = results[:limit]
return results
Expand Down
Loading

0 comments on commit 0859724

Please sign in to comment.