Skip to content

Commit

Permalink
Merge branch 'main' into docs/54-documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Tiago Würthner committed May 28, 2024
2 parents e2b5127 + 2a52229 commit b3f19dc
Show file tree
Hide file tree
Showing 23 changed files with 1,356 additions and 1,161 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
mlruns/
scratch/
dataset/
data/
plots/
data/

docs/source

Expand Down
16 changes: 15 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,31 @@ FROM archlinux
# Install Python, Poetry, and cosmetic/quality-of-life tools from the Arch system repos
RUN pacman-key --init
RUN pacman-key --populate
RUN pacman -Syu python python-poetry ranger neovim eza git tree zsh openssh which neofetch github-cli make --noconfirm
RUN pacman -Syu python python-poetry ranger neovim eza git tree zsh openssh which neofetch github-cli make binutils gcc pkg-config fakeroot debugedit --noconfirm

# Set working directory and copy over config files and install python packages
RUN mkdir -p /home/steve

# Get Arial font
RUN cd /home/steve
WORKDIR /home/steve
RUN git clone https://aur.archlinux.org/ttf-ms-fonts.git
RUN mv ttf-ms-fonts/* .
RUN chmod 777 .
RUN runuser -unobody makepkg
RUN pacman -U ttf-ms-fonts*.pkg.tar.zst --noconfirm
RUN rm -r ./*

# Get Project
ADD https://api.github.com/repos/mlederbauer/NMRcraft/git/refs/heads/main version.json
RUN git clone https://github.com/mlederbauer/NMRcraft.git /home/steve/NMRcraft
WORKDIR /home/steve/NMRcraft
RUN echo "🚀 Creating virtual environment using pyenv and poetry"
RUN poetry install
RUN poetry run pre-commit install



# Quality of Life stuff
ADD https://api.github.com/repos/tiaguinho-code/Archpy_dots/git/refs/heads/main version.json
RUN git clone https://github.com/tiaguinho-code/Archpy_dots /home/steve/Archpy_dots
Expand Down
134 changes: 98 additions & 36 deletions nmrcraft/analysis/plotting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Functions to plot."""

import os

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
Expand All @@ -15,11 +20,13 @@ def style_setup():
plt.rcParams["text.latex.preamble"] = r"\usepackage{sansmathfonts}"
plt.rcParams["axes.prop_cycle"] = cycler(color=colors)

# Use the first color from the custom color cycle
first_color = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
all_colors = [
plt.rcParams["axes.prop_cycle"].by_key()["color"][i]
for i in range(len(colors))
]
plt.rcParams["text.usetex"] = False

return cmap, colors, first_color
return cmap, colors, all_colors


def plot_predicted_vs_ground_truth(
Expand All @@ -33,7 +40,8 @@ def plot_predicted_vs_ground_truth(
Returns:
None
"""
_, _, first_color = style_setup()
_, _, colors = style_setup()
first_color = colors[0]
# Creating the plot
plt.figure(figsize=(10, 8))
plt.scatter(y_test, y_pred, color=first_color, edgecolor="k", alpha=0.6)
Expand Down Expand Up @@ -85,7 +93,7 @@ def plot_predicted_vs_ground_truth_density(


def plot_confusion_matrix(
cm, classes, title, path, full=True, columns_set=False
cm_list, y_labels, model_name, dataset_size, folder_path: str = "plots/"
):
"""
Plots the confusion matrix.
Expand All @@ -98,45 +106,27 @@ def plot_confusion_matrix(
Returns:
None
"""
_, _, _ = style_setup()
if full: # Plot one big cm
if not os.path.exists(folder_path):
os.makedirs(folder_path)
# _, _, _ = style_setup()
for target in y_labels:
file_path = os.path.join(
folder_path,
f"ConfusionMatrix_{model_name}_{dataset_size}_{target}.png",
)
cm = cm_list[target]
classes = y_labels[target]
plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
plt.title(title)
plt.title(f"{target} Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.xticks(tick_marks, classes, rotation=90)
plt.yticks(tick_marks, classes)
plt.tight_layout()
plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.savefig(path)
plt.close()

elif not full: # Plot many small cms of each target
cms = []
for columns in columns_set: # Make list of confusion matrices
cms.append(
cm[
slice(columns[0], columns[-1] + 1),
slice(columns[0], columns[-1] + 1),
]
)
fig, axs = plt.subplots(nrows=len(cms), figsize=(10, 8 * len(cms)))
for i, sub_cm in enumerate(cms):
sub_classes = classes[
slice(columns_set[i][0], columns_set[i][-1] + 1)
]
axs[i].imshow(sub_cm, interpolation="nearest", cmap=plt.cm.Blues)
axs[i].set_title(f"Confusion Matrix {i+1}")
tick_marks = np.arange(len(sub_classes))
axs[i].set_xticks(tick_marks)
axs[i].set_xticklabels(sub_classes, rotation=45)
axs[i].set_yticks(tick_marks)
axs[i].set_yticklabels(sub_classes)
plt.tight_layout()
print(cm)
plt.savefig(path)
plt.savefig(file_path)
plt.close()


Expand Down Expand Up @@ -167,3 +157,75 @@ def plot_roc_curve(fpr, tpr, roc_auc, title, path):
plt.legend(loc="lower right")
plt.savefig(path)
plt.close()


def plot_with_without_ligands_bar(
    df, output_path="plots/exp3_incorporate_ligand_info.png"
):
    """Plot grouped accuracy bars per target and save the figure to disk.

    For every unique value in ``df["target"]`` the matching rows are drawn as
    a pair of bars (one bar per row) with asymmetric error bars derived from
    the lower/upper accuracy bounds.

    Args:
        df: DataFrame with the columns ``target``, ``accuracy_mean``,
            ``accuracy_lower_bd``, ``accuracy_upper_bd`` and
            ``nmr_tensor_input_only``. Two bar positions are generated per
            target, so each target is expected to have exactly two rows.
        output_path: Destination of the saved image. Defaults to the
            location that was previously hard-coded.

    Returns:
        None
    """
    categories = df["target"].unique()
    _, _, colors = style_setup()
    first_color = colors[0]
    second_color = colors[1]

    x_pos = np.arange(len(categories))
    bar_width = 0.35

    # Initialize plot
    fig, ax = plt.subplots()

    # Loop through each category and plot bars
    for i, category in enumerate(categories):
        subset = df[df["target"] == category]

        # Means and asymmetric error bars (distance below / above the mean)
        means = subset["accuracy_mean"].values
        errors = [
            means - subset["accuracy_lower_bd"].values,
            subset["accuracy_upper_bd"].values - means,
        ]

        # Bar locations for the group: one bar left and one right of the tick
        bar_positions = x_pos[i] + np.array([-bar_width / 2, bar_width / 2])

        # Determine bar colors based on 'nmr_tensor_input_only' field
        # NOTE(review): truthy 'nmr_tensor_input_only' rows get the colour the
        # legend labels "With Ligand Info" -- confirm this mapping is intended.
        bar_colors = [
            first_color if x else second_color
            for x in subset["nmr_tensor_input_only"]
        ]

        # Plotting the bars
        ax.bar(
            bar_positions,
            means,
            yerr=np.array(errors),
            color=bar_colors,
            align="center",
            ecolor="black",
            capsize=5,
            width=bar_width,
        )

    # Labeling and aesthetics
    ax.set_ylabel("Accuracy / %")
    ax.set_xlabel("Target(s)")
    ax.set_xticks(x_pos)
    ax.set_xticklabels(categories)
    ax.set_title("Accuracy Measurements with Error Bars")

    handles = [
        mpatches.Patch(color=first_color, label="With Ligand Info"),
        mpatches.Patch(color=second_color, label="Without Ligand Info"),
    ]
    ax.legend(handles=handles, loc="best", fontsize=20)
    plt.tight_layout()

    # Make sure the target folder exists; savefig does not create it.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    plt.savefig(output_path)
    # Release the figure, as the other plotting helpers in this module do.
    plt.close(fig)
    print(f"Saved to {output_path}")


# Script entry point: the module name is "__main__" (with underscores) when
# the file is executed directly; comparing against "main" never matched.
if __name__ == "__main__":
    import pandas as pd

    df = pd.read_csv("dataset/path_to_results.csv")
    plot_with_without_ligands_bar(df)
76 changes: 76 additions & 0 deletions nmrcraft/data/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Load and preprocess data."""

import os

import pandas as pd
from datasets import load_dataset


class DatasetLoadError(FileNotFoundError):
    """Raised when the raw dataset could not be loaded.

    The value passed to the constructor is embedded in the error message to
    identify the dataset that failed to load.
    """

    def __init__(self, t):
        message = f"Could not load raw Dataset '{t}'"
        super().__init__(message)


class InvalidTargetError(ValueError):
    """Raised when a requested target is not valid.

    The offending target is embedded in the error message.
    """

    def __init__(self, t):
        message = f"Invalid target '{t}'"
        super().__init__(message)


def filename_to_ligands(dataset: pd.DataFrame):
    """Split ``file_name`` on underscores and add one column per field.

    The fields are assigned, in order, to the columns ``metal``,
    ``geometry``, ``E_ligand``, ``X1_ligand``, ``X2_ligand``, ``X3_ligand``,
    ``X4_ligand`` and ``L_ligand``. Rows without an eighth field get the
    string ``"none"`` as ``L_ligand``.

    Args:
        dataset: DataFrame containing a string column ``file_name``. The
            new columns are added to this DataFrame in place.

    Returns:
        The same DataFrame, with the additional columns.
    """
    filename_parts = dataset["file_name"].str.split("_", expand=True)

    leading_fields = (
        "metal",
        "geometry",
        "E_ligand",
        "X1_ligand",
        "X2_ligand",
        "X3_ligand",
        "X4_ligand",
    )
    for position, column in enumerate(leading_fields):
        dataset[column] = filename_parts.get(position)

    # When no file name has an eighth field the split result has no column 7
    # and ``get`` returns None; calling ``fillna`` on it would raise.
    l_ligand = filename_parts.get(7)
    if l_ligand is None:
        dataset["L_ligand"] = "none"
    else:
        # Fill missing L_ligand with 'none'
        dataset["L_ligand"] = l_ligand.fillna("none")
    return dataset


def load_dummy_dataset_locally(datset_path: str = "tests/data.csv"):
    """Read a CSV file from disk and return it as a pandas DataFrame.

    Args:
        datset_path: Path of the CSV file to read. Defaults to
            ``tests/data.csv``.

    Returns:
        pandas.DataFrame: The parsed contents of the file.
    """
    return pd.read_csv(datset_path)


def load_dataset_from_hf(
    dataset_name: str = "NMRcraft/nmrcraft", data_files: str = "all_no_nan.csv"
):
    """Load the dataset, downloading it once and caching it locally.

    If ``dataset/dataset.csv`` does not exist yet, the dataset is fetched via
    ``load_dataset`` (its ``"train"`` split), converted to pandas and written
    to that path. The cached CSV is then read and returned. An existing cache
    file is reused regardless of the arguments passed.

    It assumes that you have logged into the Hugging Face CLI prior to
    calling this function.

    Args:
        dataset_name (str, optional): The name of the dataset. Defaults to
            "NMRcraft/nmrcraft".
        data_files (str, optional): The name of the data file. Defaults to
            'all_no_nan.csv'.

    Returns:
        pandas.DataFrame: The loaded dataset as a pandas DataFrame.

    Raises:
        DatasetLoadError: If the cache file is still missing after the
            download attempt.
    """
    # Create data dir if needed
    os.makedirs("dataset", exist_ok=True)

    # Check if hf dataset is already downloaded, else download it
    if not os.path.isfile("dataset/dataset.csv"):
        downloaded = load_dataset(dataset_name, data_files=data_files)[
            "train"
        ].to_pandas()
        # NOTE(review): the index is written too, so the re-read frame gains
        # an extra unnamed column; kept as-is so the returned columns do not
        # change for existing callers and caches.
        downloaded.to_csv("dataset/dataset.csv")

    if not os.path.isfile("dataset/dataset.csv"):
        # Report the file that is missing rather than an exception class.
        raise DatasetLoadError("dataset/dataset.csv")
    return pd.read_csv("dataset/dataset.csv")
Loading

0 comments on commit b3f19dc

Please sign in to comment.