Skip to content

Commit

Permalink
Added argps to visualizer to handle different max_evals in single
Browse files Browse the repository at this point in the history
metrics file
  • Loading branch information
Tiago Würthner committed May 29, 2024
1 parent 8024b16 commit 41bb2db
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions scripts/analysis/visualize_results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import ast
import os

Expand All @@ -6,7 +7,7 @@
from nmrcraft.analysis.plotting import plot_bar, plot_metric


def load_results(results_dir: str, baselines_dir: str):
def load_results(results_dir: str, baselines_dir: str, max_evals: int):
import_filename_base = os.path.join(baselines_dir, "results_baselines.csv")
import_filename_one = os.path.join(results_dir, "results_one_target.csv")
import_filename_multi = os.path.join(
Expand All @@ -15,8 +16,9 @@ def load_results(results_dir: str, baselines_dir: str):

df_base = pd.read_csv(import_filename_base)
df_one = pd.read_csv(import_filename_one)
df_one = df_one[df_one["max_evals"] == max_evals]
df_multi = pd.read_csv(import_filename_multi)

df_multi = df_multi[df_multi["max_evals"] == max_evals]
return df_base, df_one, df_multi


Expand Down Expand Up @@ -134,13 +136,29 @@ def plot_exp_3(df_one, df_multi):
return


# Setup parser
parser = argparse.ArgumentParser(
description="Train a model with MLflow tracking."
)

parser.add_argument(
"--max_evals",
type=int,
default=100,
help="How many max_evals the analysed data has",
)
# Add arguments
args = parser.parse_args()

if __name__ == "__main__":

if not os.path.exists("./plots/results/"):
os.makedirs("./plots/results/")

df_base, df_one, df_multi = load_results(
results_dir="metrics/20eval/", baselines_dir="metrics/"
results_dir="metrics/20eval/",
baselines_dir="metrics/",
max_evals=args.max_evals,
)

plot_exp_1(df_base, df_one)
Expand Down

0 comments on commit 41bb2db

Please sign in to comment.