diff --git a/.gitignore b/.gitignore index d07bbaf..ced95ee 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,8 @@ -metrics/ mlruns/ scratch/ dataset/ plots/ data/ -# metrics/ docs/source diff --git a/metrics/results_baselines.csv b/metrics/results_baselines.csv index c48c274..8e21231 100644 --- a/metrics/results_baselines.csv +++ b/metrics/results_baselines.csv @@ -1,25 +1,49 @@ target,model_targets,model,nmr_only,dataset_fraction,max_evals,accuracy_mean,accuracy_lb,accuracy_hb,f1_mean,f1_lb,f1_hb -metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,0.01,3,0.5058823529411764,0.46903747319483347,0.5427272326875194,0.525925742192223,0.4851878096297068,0.5666636747547392 -X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,0.01,3,0.06470588235294118,0.030635374759868085,0.09877638994601429,0.08636157959687371,0.03966676646122087,0.13305639273252656 -E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,0.01,3,0.03529411764705882,0.01595977555433784,0.05462845973977981,0.041181139122315594,0.017173054912510904,0.06518922333212028 -metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,0.1,3,0.5041297935103245,0.47971363569262876,0.5285459513280203,0.507117600161758,0.4822071364397943,0.5320280638837217 -X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,0.1,3,0.06548672566371681,0.05819033705083736,0.07278311427659627,0.09252484300302127,0.08395637262641044,0.10109331337963211 -E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,0.1,3,0.05162241887905604,0.04166242641784486,0.06158241134026722,0.058516000791620684,0.04638311412994394,0.07064888745329742 -metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,0.5,3,0.49965075669383,0.49337506938944115,0.505926443998219,0.5000578704106552,0.4935612606620888,0.5065544801592217 -X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,0.5,3,0.06187427240977881,0.058395766411096006,0.06535277840846161,0.082683127199205,0.07829081930268664,0.08707543509572335 -E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,0.5,3,0.05372526193247963,0.050042314398458236,0.05740820946650103,0.05802572886695826,0.05414810474032928,0.06190335299358724 -metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,1.0,3,0.5044554455445545,0.4973103137192605,0.5116005773698484,0.5057493447398278,0.4986128796558878,0.5128858098237679 -X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,1.0,3,0.0627256843331392,0.060436946191957476,0.06501442247432092,0.0839813406413514,0.08049729203671255,0.08746538924599026 -E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_ligand,True,1.0,3,0.053377984857309255,0.05029433048203211,0.0564616392325864,0.056778787239151515,0.05353046998226371,0.06002710449603932 -metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.01,3,0.6323529411764706,0.5656343461925937,0.6990715361603475,0.49350566500917703,0.40976646143588563,0.5772448685824684 -X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.01,3,0.6058823529411764,0.5591503831326147,0.6526143227497382,0.4590187951703406,0.40134371692629056,0.5166938734143907 -E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.01,3,0.07352941176470588,0.04026239466391031,0.10679642886550146,0.013180538072178938,0.0035781335667639056,0.022782942577593973 -metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.1,3,0.5504347826086956,0.5318515382893808,0.5690180269280104,0.39115298476660565,0.3692923496172117,0.4130136199159996 -X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.1,3,0.48434782608695653,0.4540080750495723,0.5146875771243408,0.3170652635457896,0.2835364865123042,0.35059404057927507 -E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.1,3,0.04579710144927537,0.032119694592215864,0.05947450830633487,0.004588130729867313,0.0022739116381336392,0.0069023498216009855 -metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.5,3,0.582213209733488,0.5653184319174177,0.5991079875495583,0.4287331732475139,0.4085006111881631,0.44896573530686473 -X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.5,3,0.5331402085747394,0.5231075895481314,0.5431728276013474,0.37089044921066844,0.3593313499013468,0.3824495485199901 -E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.5,3,0.07363847045191194,0.06268257012397949,0.08459437077984439,0.010439663956871964,0.007365014971372833,0.013514312942371095 -metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,1.0,3,0.5614443797320908,0.553622739904657,0.5692660195595246,0.4038103418492868,0.3945499679316787,0.41307071576689497 -X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,1.0,3,0.5454280722189866,0.5386138801269157,0.5522422643110575,0.38504003512364116,0.37709529074375503,0.3929847795035273 -E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,1.0,3,0.07443214909726266,0.07054085806720904,0.07832344012731628,0.010355603186241442,0.009309143273654555,0.01140206309882833 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.01,0,0.5011065812463599,0.4965175326444776,0.5056956298482421,0.5014009988423164,0.4967589741412302,0.5060430235434025 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.01,0,0.0787711124053582,0.076373215967909,0.0811690088428073,0.1051106235740993,0.1008328543117485,0.1093883928364501 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.01,0,0.0534653465346534,0.0517809762905,0.0551497167788069,0.0573288138474562,0.0553372163568755,0.0593204113380369 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.1,0,0.5039895165987187,0.4984222433086742,0.5095567898887632,0.504246990530499,0.4987109548060941,0.509783026254904 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.1,0,0.0612987769365171,0.0595718614642057,0.0630256924088285,0.0837710216423869,0.0806545778623966,0.0868874654223772 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.1,0,0.0527955736750145,0.0492725631884355,0.0563185841615936,0.0565026556395903,0.0528037260864958,0.0602015851926848 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.5,0,0.5038147932440302,0.4958390463414746,0.5117905401465859,0.5041135779082694,0.4960944460696257,0.512132709746913 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.5,0,0.0625800815375655,0.0607815000270324,0.0643786630480985,0.0837624472075781,0.0809442246601338,0.0865806697550224 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.5,0,0.0507862550960978,0.0484046007238828,0.0531679094683128,0.0542342326529993,0.0520807848007603,0.0563876805052384 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,1.0,0,0.4997087944088527,0.4913975950068232,0.5080199938108821,0.500011318464902,0.4918249287029309,0.508197708226873 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,1.0,0,0.0614735002912055,0.0584238075564098,0.0645231930260013,0.0816686938778326,0.0781750919214746,0.0851622958341906 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,1.0,0,0.0521840419336051,0.0502328284051481,0.0541352554620621,0.0563427574367348,0.0541715935384499,0.0585139213350196 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.01,0,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.01,0,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.01,0,0.0875072801397786,0.0831329211541471,0.0918816391254102,0.01413505179764,0.0127848826883133,0.0154852209069667 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.1,0,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.1,0,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.1,0,0.0926033779848573,0.0845917111091493,0.1006150448605652,0.0158704018957468,0.0132779342365727,0.018462869554921 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.5,0,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.5,0,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.5,0,0.0875072801397786,0.0831329211541471,0.0918816391254102,0.01413505179764,0.0127848826883133,0.0154852209069667 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,1.0,0,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,1.0,0,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,1.0,0,0.0712288875946418,0.0657956076933005,0.0766621674959831,0.0095570714232258,0.008181721990006,0.0109324208564457 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.01,0,0.5011065812463599,0.4965175326444776,0.5056956298482421,0.5014009988423164,0.4967589741412302,0.5060430235434025 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.01,0,0.0787711124053582,0.07637321596790903,0.08116900884280737,0.10511062357409932,0.10083285431174853,0.10938839283645012 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.01,0,0.05346534653465347,0.05178097629050004,0.0551497167788069,0.057328813847456236,0.055337216356875565,0.05932041133803691 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.1,0,0.5039895165987187,0.49842224330867424,0.5095567898887632,0.504246990530499,0.4987109548060941,0.509783026254904 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.1,0,0.06129877693651718,0.05957186146420578,0.06302569240882858,0.08377102164238696,0.08065457786239665,0.08688746542237728 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.1,0,0.052795573675014563,0.0492725631884355,0.056318584161593625,0.05650265563959038,0.05280372608649589,0.06020158519268487 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.5,0,0.5038147932440302,0.4958390463414746,0.5117905401465859,0.5041135779082694,0.49609444606962577,0.512132709746913 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.5,0,0.06258008153756552,0.06078150002703248,0.06437866304809854,0.08376244720757814,0.08094422466013382,0.08658066975502246 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,0.5,0,0.05078625509609784,0.0484046007238828,0.05316790946831288,0.05423423265299937,0.05208078480076033,0.05638768050523841 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,1.0,0,0.4997087944088527,0.4913975950068232,0.5080199938108821,0.500011318464902,0.49182492870293093,0.508197708226873 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,1.0,0,0.06147350029120559,0.0584238075564098,0.06452319302600137,0.08166869387783267,0.07817509192147466,0.08516229583419067 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_random_selector,True,1.0,0,0.05218404193360513,0.050232828405148165,0.0541352554620621,0.056342757436734815,0.05417159353844995,0.05851392133501968 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.01,0,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.34956280863871386,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.01,0,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.40366185443745684,0.43094565857870937 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.01,0,0.08750728013977868,0.08313292115414715,0.09188163912541021,0.014135051797640089,0.012784882688313383,0.015485220906966794 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.1,0,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.34956280863871386,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.1,0,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.40366185443745684,0.43094565857870937 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.1,0,0.09260337798485731,0.08459171110914934,0.10061504486056527,0.015870401895746893,0.013277934236572721,0.018462869554921064 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.5,0,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.34956280863871386,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.5,0,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.40366185443745684,0.43094565857870937 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,0.5,0,0.08750728013977868,0.08313292115414715,0.09188163912541021,0.014135051797640089,0.012784882688313383,0.015485220906966794 +metal,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,1.0,0,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.34956280863871386,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,1.0,0,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.40366185443745684,0.43094565857870937 +E_ligand,"['metal', 'X3_ligand', 'E_ligand']",baseline_most_often,True,1.0,0,0.07122888759464183,0.06579560769330053,0.07666216749598312,0.009557071423225882,0.008181721990006016,0.010932420856445748 diff --git a/metrics/results_multi_target.csv b/metrics/results_multi_target.csv new file mode 100644 index 0000000..120a078 --- /dev/null +++ b/metrics/results_multi_target.csv @@ -0,0 +1,25 @@ +target,model_targets,model,nmr_only,dataset_fraction,max_evals,accuracy_mean,accuracy_lb,accuracy_hb,f1_mean,f1_lb,f1_hb +metal,"['metal', 'E_ligand']",random_forest,True,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +E_ligand,"['metal', 'E_ligand']",random_forest,True,1.0,1,0.0712288875946418,0.0657956076933005,0.0766621674959831,0.0095570714232258,0.008181721990006,0.0109324208564457 +metal,"['metal', 'E_ligand']",extra_trees,True,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +E_ligand,"['metal', 'E_ligand']",extra_trees,True,1.0,1,0.1249271986022131,0.1172073201240264,0.1326470770803999,0.0338563856406792,0.0309963687490145,0.0367164025323438 +metal,"['metal', 'X3_ligand']",random_forest,True,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand']",random_forest,True,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +metal,"['metal', 'X3_ligand']",extra_trees,True,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +X3_ligand,"['metal', 'X3_ligand']",extra_trees,True,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,"['X3_ligand', 'E_ligand']",random_forest,True,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +E_ligand,"['X3_ligand', 'E_ligand']",random_forest,True,1.0,1,0.0712288875946418,0.0657956076933005,0.0766621674959831,0.0095570714232258,0.008181721990006,0.0109324208564457 +X3_ligand,"['X3_ligand', 'E_ligand']",extra_trees,True,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +E_ligand,"['X3_ligand', 'E_ligand']",extra_trees,True,1.0,1,0.1249271986022131,0.1172073201240264,0.1326470770803999,0.0338563856406792,0.0309963687490145,0.0367164025323438 +metal,"['metal', 'E_ligand', 'X3_ligand']",random_forest,False,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +E_ligand,"['metal', 'E_ligand', 'X3_ligand']",random_forest,False,1.0,1,0.0712288875946418,0.0657956076933005,0.0766621674959831,0.0095570714232258,0.008181721990006,0.0109324208564457 +X3_ligand,"['metal', 'E_ligand', 'X3_ligand']",random_forest,False,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +metal,"['metal', 'E_ligand', 'X3_ligand']",extra_trees,False,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +E_ligand,"['metal', 'E_ligand', 'X3_ligand']",extra_trees,False,1.0,1,0.1249271986022131,0.1172073201240264,0.1326470770803999,0.0338563856406792,0.0309963687490145,0.0367164025323438 +X3_ligand,"['metal', 'E_ligand', 'X3_ligand']",extra_trees,False,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +metal,"['metal', 'E_ligand', 'X3_ligand']",random_forest,True,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.34956280863871386,0.3682500864133434 +E_ligand,"['metal', 'E_ligand', 'X3_ligand']",random_forest,True,1.0,1,0.07122888759464183,0.06579560769330053,0.07666216749598312,0.009557071423225882,0.008181721990006016,0.010932420856445748 +X3_ligand,"['metal', 'E_ligand', 'X3_ligand']",random_forest,True,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.40366185443745684,0.43094565857870937 +metal,"['metal', 'E_ligand', 'X3_ligand']",extra_trees,True,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.34956280863871386,0.3682500864133434 +E_ligand,"['metal', 'E_ligand', 'X3_ligand']",extra_trees,True,1.0,1,0.12492719860221317,0.11720732012402643,0.1326470770803999,0.03385638564067921,0.03099636874901453,0.036716402532343886 +X3_ligand,"['metal', 'E_ligand', 'X3_ligand']",extra_trees,True,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.40366185443745684,0.43094565857870937 diff --git a/metrics/results_one_target.csv b/metrics/results_one_target.csv new file mode 100644 index 0000000..7e6dc47 --- /dev/null +++ b/metrics/results_one_target.csv @@ -0,0 +1,81 @@ +target,model_targets,model,nmr_only,dataset_fraction,max_evals,accuracy_mean,accuracy_lb,accuracy_hb,f1_mean,f1_lb,f1_hb +metal,['metal'],random_forest,True,0.01,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +metal,['metal'],random_forest,True,0.1,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +metal,['metal'],random_forest,True,0.5,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +metal,['metal'],random_forest,True,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +metal,['metal'],logistic_regression,True,0.01,1,0.8343040186371578,0.825730103387833,0.8428779338864825,0.8343250490971543,0.825782360271246,0.8428677379230627 +metal,['metal'],logistic_regression,True,0.1,1,0.8348864298194526,0.8288929456167322,0.840879914022173,0.834610615171053,0.828708515672909,0.8405127146691969 +metal,['metal'],logistic_regression,True,0.5,1,0.8320034944670939,0.8278078256759076,0.8361991632582801,0.8318459611604336,0.8277165255108725,0.8359753968099948 +metal,['metal'],logistic_regression,True,1.0,1,0.8322655794991263,0.8281213081236869,0.8364098508745657,0.8321224573549054,0.8280365902739946,0.8362083244358163 +metal,['metal'],gradient_boosting,True,0.01,1,0.8073383808969133,0.796167949663936,0.8185088121298907,0.8066755678622277,0.7955687078624982,0.8177824278619572 +metal,['metal'],gradient_boosting,True,0.1,1,0.7835760046592894,0.7726788710900602,0.7944731382285187,0.7829576911769481,0.771943353531518,0.7939720288223783 +metal,['metal'],gradient_boosting,True,0.5,1,0.7862842166569599,0.7781220297446891,0.7944464035692307,0.784904554070982,0.7767403546676637,0.7930687534743004 +metal,['metal'],gradient_boosting,True,1.0,1,0.7864298194525334,0.77813941536259,0.7947202235424768,0.7852154832120501,0.7769425611867106,0.7934884052373897 +metal,['metal'],svc,True,0.01,1,0.5115026208503203,0.5031012623511539,0.5199039793494866,0.4991979766766928,0.4898739455100052,0.5085220078433804 +metal,['metal'],svc,True,0.1,1,0.4922248107163657,0.4844967856940064,0.4999528357387249,0.4886566034431118,0.4798741771938942,0.4974390296923293 +metal,['metal'],svc,True,0.5,1,0.5005824111822947,0.4938414959274986,0.5073233264370909,0.498411290949717,0.4909268707880524,0.5058957111113814 +metal,['metal'],svc,True,1.0,1,0.4955445544554455,0.4895069452568495,0.5015821636540414,0.4935905393620959,0.4867942320540067,0.5003868466701853 +metal,['metal'],extra_trees,True,0.01,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +metal,['metal'],extra_trees,True,0.1,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +metal,['metal'],extra_trees,True,0.5,1,0.6499708794408853,0.6418178245605629,0.6581239343212076,0.6151101456106673,0.6062261328282443,0.6239941583930904 +metal,['metal'],extra_trees,True,1.0,1,0.5226849155503785,0.5144697341323923,0.5309000969683648,0.3589064475260286,0.3495628086387138,0.3682500864133434 +X3_ligand,['X3_ligand'],random_forest,False,0.01,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],random_forest,False,0.1,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],random_forest,False,0.5,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],random_forest,False,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],logistic_regression,False,0.01,1,0.5105707629586488,0.4941409136932588,0.5270006122240387,0.4965394435459258,0.4791361771506772,0.5139427099411743 +X3_ligand,['X3_ligand'],logistic_regression,False,0.1,1,0.5729470005824111,0.5613068592799262,0.584587141884896,0.4784522950970181,0.4654989406356724,0.4914056495583637 +X3_ligand,['X3_ligand'],logistic_regression,False,0.5,1,0.5857600465928947,0.5755173170642385,0.596002776121551,0.4876580374237079,0.4758168736457143,0.4994992012017015 +X3_ligand,['X3_ligand'],logistic_regression,False,1.0,1,0.5854105998835177,0.5745355017981558,0.5962856979688796,0.4890392451659134,0.4762098737360356,0.5018686165957913 +X3_ligand,['X3_ligand'],gradient_boosting,False,0.01,1,0.5686953989516599,0.5527615526086911,0.5846292452946287,0.4418639081665366,0.4260073130636935,0.4577205032693796 +X3_ligand,['X3_ligand'],gradient_boosting,False,0.1,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],gradient_boosting,False,0.5,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],gradient_boosting,False,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],svc,False,0.01,1,0.576703552708212,0.5652278005369593,0.5881793048794648,0.4291284758623678,0.4165167515257382,0.4417402001989973 +X3_ligand,['X3_ligand'],svc,False,0.1,1,0.5387885847408269,0.528207538908018,0.5493696305736359,0.4069264557099435,0.3951838987822181,0.418669012637669 +X3_ligand,['X3_ligand'],svc,False,0.5,1,0.4839254513686663,0.4755163198210584,0.4923345829162742,0.3824012809352141,0.3718288505940441,0.3929737112763841 +X3_ligand,['X3_ligand'],svc,False,1.0,1,0.4988060570762959,0.4932784182905224,0.5043336958620694,0.3867815890899114,0.3779846778329936,0.3955785003468292 +X3_ligand,['X3_ligand'],extra_trees,False,0.01,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],extra_trees,False,0.1,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],extra_trees,False,0.5,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],extra_trees,False,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],random_forest,True,0.01,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],random_forest,True,0.1,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],random_forest,True,0.5,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],random_forest,True,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],logistic_regression,True,0.01,1,0.5453407105416425,0.5317935710837208,0.5588878499995641,0.4156828608920712,0.4023027393983311,0.4290629823858112 +X3_ligand,['X3_ligand'],logistic_regression,True,0.1,1,0.567909143855562,0.5567587592061309,0.5790595285049931,0.4222840796935112,0.4097426400536114,0.434825519333411 +X3_ligand,['X3_ligand'],logistic_regression,True,0.5,1,0.5740244612696564,0.5625429962576471,0.5855059262816658,0.4270857851835308,0.4135381866223432,0.4406333837447185 +X3_ligand,['X3_ligand'],logistic_regression,True,1.0,1,0.5733255678509027,0.562388243040162,0.5842628926616434,0.4266481712133332,0.4143237895160948,0.4389725529105717 +X3_ligand,['X3_ligand'],gradient_boosting,True,0.01,1,0.5632207338380897,0.5518788501775506,0.5745626174986288,0.4262974855912184,0.4133449528492375,0.4392500183331993 +X3_ligand,['X3_ligand'],gradient_boosting,True,0.1,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],gradient_boosting,True,0.5,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],gradient_boosting,True,1.0,1,0.5719569015725102,0.5602773919071311,0.5836364112378893,0.4169268922810233,0.403191267679328,0.4306625168827185 +X3_ligand,['X3_ligand'],svc,True,0.01,1,0.4970879440885264,0.4854061434065414,0.5087697447705114,0.4011186781937813,0.387451386758047,0.4147859696295155 +X3_ligand,['X3_ligand'],svc,True,0.1,1,0.5177635410599883,0.5066399124886503,0.5288871696313264,0.4037669284008751,0.3903986716933493,0.4171351851084008 +X3_ligand,['X3_ligand'],svc,True,0.5,1,0.4993593476994758,0.4902674981362743,0.5084511972626773,0.3930110622425491,0.3818532175986829,0.4041689068864154 +X3_ligand,['X3_ligand'],svc,True,1.0,1,0.5070180547466512,0.4973128025604256,0.5167233069328767,0.3983345218565252,0.3857557277687919,0.4109133159442584 +X3_ligand,['X3_ligand'],extra_trees,True,0.01,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],extra_trees,True,0.1,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],extra_trees,True,0.5,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +X3_ligand,['X3_ligand'],extra_trees,True,1.0,1,0.5727722772277227,0.5613014630738213,0.5842430913816241,0.4173037565080831,0.4036618544374568,0.4309456585787093 +E_ligand,['E_ligand'],random_forest,True,0.01,1,0.08308095515433897,0.08064356974299355,0.08551834056568439,0.01276237998713228,0.012039963134827471,0.013484796839437088 +E_ligand,['E_ligand'],random_forest,True,0.1,1,0.08308095515433897,0.08064356974299355,0.08551834056568439,0.01276237998713228,0.012039963134827471,0.013484796839437088 +E_ligand,['E_ligand'],random_forest,True,0.5,1,0.09260337798485731,0.08459171110914934,0.10061504486056527,0.015870401895746893,0.013277934236572721,0.018462869554921064 +E_ligand,['E_ligand'],random_forest,True,1.0,1,0.07122888759464183,0.06579560769330053,0.07666216749598312,0.009557071423225882,0.008181721990006016,0.010932420856445748 +E_ligand,['E_ligand'],logistic_regression,True,0.01,1,0.32364589400116484,0.31959434104721873,0.32769744695511094,0.2938445932432909,0.28960898150064196,0.2980802049859398 +E_ligand,['E_ligand'],logistic_regression,True,0.1,1,0.4139196272568434,0.4082466983124881,0.4195925562011987,0.37889605684311006,0.3735003508952647,0.38429176279095545 +E_ligand,['E_ligand'],logistic_regression,True,0.5,1,0.4284507862550961,0.4218969774059309,0.43500459510426126,0.4013686612957589,0.39444311039289237,0.4082942121986254 +E_ligand,['E_ligand'],logistic_regression,True,1.0,1,0.43439138031450203,0.42892461389499353,0.4398581467340105,0.4034202897832273,0.3979001541058436,0.408940425460611 +E_ligand,['E_ligand'],gradient_boosting,True,0.01,1,0.23677926616191028,0.23165353708686973,0.24190499523695083,0.22328891505949117,0.21867613169848943,0.2279016984204929 +E_ligand,['E_ligand'],gradient_boosting,True,0.1,1,0.28168316831683166,0.2782056161570456,0.2851607204766177,0.24124403641774705,0.23820288012408192,0.2442851927114122 +E_ligand,['E_ligand'],gradient_boosting,True,0.5,1,0.27690739662201513,0.2683394164457738,0.28547537679825646,0.22320224804248814,0.21628645986481607,0.23011803622016022 +E_ligand,['E_ligand'],gradient_boosting,True,1.0,1,0.2750728013977869,0.26519482790926224,0.28495077488631154,0.22405311543606898,0.21518389192411322,0.23292233894802475 +E_ligand,['E_ligand'],svc,True,0.01,1,0.212085032032615,0.205653761465392,0.218516302599838,0.1624527095156013,0.15520998386352775,0.16969543516767488 +E_ligand,['E_ligand'],svc,True,0.1,1,0.1624927198602213,0.15590226458149706,0.16908317513894552,0.1434276390517042,0.138395856868238,0.1484594212351704 +E_ligand,['E_ligand'],svc,True,0.5,1,0.13031450203843914,0.1259028250500221,0.13472617902685619,0.11965645814897956,0.11567108347055052,0.1236418328274086 +E_ligand,['E_ligand'],svc,True,1.0,1,0.11755969714618522,0.11389448110409917,0.12122491318827128,0.11440719237097766,0.11143717008212865,0.11737721465982666 +E_ligand,['E_ligand'],extra_trees,True,0.01,1,0.14140943506115317,0.13608371290978788,0.14673515721251845,0.040596033879813156,0.03772298318142872,0.04346908457819759 +E_ligand,['E_ligand'],extra_trees,True,0.1,1,0.08750728013977868,0.08313292115414715,0.09188163912541021,0.014135051797640089,0.012784882688313383,0.015485220906966794 +E_ligand,['E_ligand'],extra_trees,True,0.5,1,0.1302853814793244,0.12306563229273813,0.13750513066591066,0.03567518635231863,0.03292960260731864,0.03842077009731862 +E_ligand,['E_ligand'],extra_trees,True,1.0,1,0.12492719860221317,0.11720732012402643,0.1326470770803999,0.03385638564067921,0.03099636874901453,0.036716402532343886 diff --git a/nmrcraft/analysis/plotting.py b/nmrcraft/analysis/plotting.py index bf46b4f..4103bcb 100644 --- a/nmrcraft/analysis/plotting.py +++ b/nmrcraft/analysis/plotting.py @@ -183,6 +183,7 @@ def plot_metric( metric="accuracy", iterative_column="model", xdata="dataset_fraction", + legend=True, ): _, colors, _ = style_setup() if iterative_column == "target": @@ -271,6 +272,10 @@ def convert_to_labels(target_list): .reset_index() ) + # desired_index = ["Metal", "E", "X3", "Metal & E", "Metal & X3", "X3 & E", "Metal & E & X3"] + # pivot_df = aggregated_data.pivot(index="xlabel", columns="target", values="accuracy_mean") + # new_df = pivot_df.reindex(desired_index) + # Pivot the aggregated data new_df = aggregated_data.pivot( index="xlabel", columns="target", values=metric + "_mean" diff --git a/nmrcraft/utils/general.py b/nmrcraft/utils/general.py index 420199f..30e149a 100644 --- a/nmrcraft/utils/general.py +++ b/nmrcraft/utils/general.py @@ -7,7 +7,7 @@ def add_rows_metrics( dataset_size, include_structural: bool, model_name: str, - max_evals: int, + max_evals: int = 0, ): # Add all the newly generated metrics to the unified dataframe targetwise for i in range(len(statistical_metrics[0])): @@ -27,3 +27,14 @@ def add_rows_metrics( ] unified_metrics.loc[len(unified_metrics)] = new_row return unified_metrics + + +def str2bool(value: str) -> bool: + """Function converts a string to boolean in a human expected way. + + Args (str): + as a string for example 'True' or 'true' or 't' + Returns (bool): + bool corresponding to if the was true or false + """ + return value.lower() in ("yes", "true", "t", "1") diff --git a/scripts/analysis/visualize_results.py b/scripts/analysis/visualize_results.py index 8739fc4..75749ef 100755 --- a/scripts/analysis/visualize_results.py +++ b/scripts/analysis/visualize_results.py @@ -24,7 +24,10 @@ def load_results(results_dir: str, baselines_dir: str, max_evals: int): def plot_exp_1( - df_base: pd.DataFrame, df_one: pd.DataFrame, metric: str = "accuracy" + df_base: pd.DataFrame, + df_one: pd.DataFrame, + metric: str = "accuracy", + legend: bool = True, ): """Plot single output models with baselines for accuracy/f1-score as a function of dataset size. @@ -118,25 +121,32 @@ def plot_exp_1( fontsize=35, ) - # Adding the legend on the right side - # ax.legend( - # title="Model", - # bbox_to_anchor=(1.05, 0.5), - # loc="center left", - # borderaxespad=0.0, - # fontsize=20, - # ) + # Adding the legend on the right side if metric is F1-Score + if legend: + ax.legend( + title="Model", + bbox_to_anchor=(1.05, 0.5), + loc="center left", + borderaxespad=0.0, + fontsize=20, + ) + plotname = f"plots/results/01_{target}_{metric}_legend.png" + else: + plotname = f"plots/results/01_{target}_{metric}.png" # Adjust the plot layout to accommodate the legend fig.subplots_adjust(right=0.75) plt.tight_layout() # Show plot - plt.savefig(f"plots/results/01_{target}_{metric}.png") + plt.savefig(plotname) def plot_exp_1_multi( - df_base: pd.DataFrame, df_one: pd.DataFrame, metric: str = "accuracy" + df_base: pd.DataFrame, + df_one: pd.DataFrame, + metric: str = "accuracy", + legend: bool = True, ): """Plot single output models with baselines for accuracy/f1-score as a function of dataset size. @@ -229,17 +239,22 @@ def plot_exp_1_multi( f"Model Performance by Dataset Size for {target_clean}", fontsize=35, ) - - ax.legend( - title="Model", - bbox_to_anchor=(1.05, 0.5), - loc="center left", - borderaxespad=0.0, - fontsize=25, - ) + if legend: + ax.legend( + title="Model", + bbox_to_anchor=(1.05, 0.5), + loc="center left", + borderaxespad=0.0, + fontsize=25, + ) + plotname = ( + f"plots/results/01_{target}_{metric}_multioutput_legend.png" + ) + else: + plotname = f"plots/results/01_{target}_{metric}_multioutput.png" fig.subplots_adjust(right=0.75) plt.tight_layout() - plt.savefig(f"plots/results/01_{target}_{metric}_multioutput.png") + plt.savefig(plotname) def plot_exp_2(df_one, df_multi): @@ -268,38 +283,6 @@ def plot_exp_2(df_one, df_multi): ) -# def plot_exp_3(df_one, df_multi): -# """Compare whether nmr-only is set to true or false -# plot for X3 (best one target model) the bar plot with/without ligands -# plot for metal & E & X3 (best multi target model) the bar plot with/withut ligands -# legens also below the plot itself -# """ -# df_combined = pd.concat([df_one, df_multi]) -# full_df = df_combined[df_combined["dataset_fraction"] == 1] - -# models = full_df["model"].unique() -# for model in models: -# sub_df = full_df[full_df["model"] == model] -# print(sub_df) -# plot_bar( -# sub_df, -# title=f"Accuracy for {model} Predictions", -# filename=f"plots/03_accuracy_{model}.png", -# metric="accuracy", -# iterative_column="target", -# xdata="xlabel", -# ) -# plot_bar( -# sub_df, -# title=f"F1-Score for {model} Predictions", -# filename=f"plots/03_f1-score_{model}.png", -# metric="f1", -# iterative_column="target", -# xdata="xlabel", -# ) -# return - - # Setup parser parser = argparse.ArgumentParser( description="Train a model with MLflow tracking." @@ -333,5 +316,9 @@ def plot_exp_2(df_one, df_multi): baselines_dir="metrics/", max_evals=args.max_evals, ) - plot_exp_1(df_base, df_one) - plot_exp_2(df_one, df_multi) + plot_exp_1(df_base=df_base, df_one=df_one, metric="accuracy") + plot_exp_1(df_base=df_base, df_one=df_one, metric="f1") + plot_exp_1_multi(df_base=df_base, df_one=df_one) + plot_exp_1(df_base=df_base, df_one=df_one, metric="f1", legend=False) + plot_exp_1_multi(df_base=df_base, df_one=df_one, legend=False) + plot_exp_2(df_one=df_one, df_multi=df_multi) diff --git a/scripts/reproduce_results.py b/scripts/reproduce_results.py index b6df197..53e3b39 100644 --- a/scripts/reproduce_results.py +++ b/scripts/reproduce_results.py @@ -1,19 +1,4 @@ -"""Scripts for reproducing all results shown in the report. - -The project consists of 3 main "experiments". that are all enabled by: -(i) loading, splitting and preprocessing data -(ii) declaring and hyperparameter-tune models by CV -(iii) training and evaluating models -(iv) plotting results -For three "experiments": -(i) -(ii) -(iii) -The following scipts are called in this script: -- analysis: analysing the dataset with PCA -- training: training single-target, multi-target and baseline models -- plotting: plotting the results as shown in the report -""" +"""Scripts for reproducing all results shown in the report.""" import argparse import shlex @@ -39,9 +24,13 @@ def run_script(script_name, targets, include_structural, max_evals): "--max_evals", str(max_evals), ] - print("---------------------------------------------------") + print( + "---------------------------------------------------------------------" + ) print(f"Running command: {' '.join(cmd)}") - print("---------------------------------------------------") + print( + "---------------------------------------------------------------------" + ) # pylint: disable=subprocess-run-check subprocess.run(cmd, check=True, shell=False) # noqa: S603 @@ -100,22 +89,66 @@ def run_multi_target_experiments(max_evals): ) -def plot_results(script_name: str): - cmd = ["python", script_name] - print("---------------------------------------------------") +def run_baselines(): + # Run the script scripts/training/baselines.py + cmd = ["python", "scripts/training/baselines.py"] + print( + "---------------------------------------------------------------------" + ) + print(f"Running command: {' '.join(cmd)}") + print( + "---------------------------------------------------------------------" + ) + + # pylint: disable=subprocess-run-check + subprocess.run(cmd, check=True, shell=False) # noqa: S603 + + return + + +def run_visualize_results(script_name: str, max_evals: int): + cmd = [ + "python", + script_name, + "--max_evals", + str(max_evals), + "-me", + str(max_evals), + ] + print( + "---------------------------------------------------------------------" + ) print(f"Running command: {' '.join(cmd)}") - print("---------------------------------------------------") + print( + "---------------------------------------------------------------------" + ) # pylint: disable=subprocess-run-check subprocess.run(cmd, check=True, shell=False) # noqa: S603 +def run_dataframe_statistics(): + cmd = [ + "python", + "scripts/analysis/dataset_statistics.py", + ] + print( + "---------------------------------------------------------------------" + ) + print(f"Running command: {' '.join(cmd)}") + print( + "---------------------------------------------------------------------" + ) + subprocess.run(cmd, check=True, shell=False) # noqa: S603 + + def main(): parser = argparse.ArgumentParser( description="Run reproducibility script for all experiments." ) parser.add_argument( "--max_evals", + "-me", type=int, default=1, help="Max evaluations for hyperparameter tuning.", @@ -123,9 +156,13 @@ def main(): args = parser.parse_args() # run baselines + run_baselines() + run_dataframe_statistics() run_one_target_experiments(args.max_evals) run_multi_target_experiments(args.max_evals) - plot_results("scripts/analysis/visualize_results.py") + run_visualize_results( + "scripts/analysis/visualize_results.py", max_evals=args.max_evals + ) if __name__ == "__main__": diff --git a/scripts/training/baselines.py b/scripts/training/baselines.py index 6019c26..6fcb8b8 100644 --- a/scripts/training/baselines.py +++ b/scripts/training/baselines.py @@ -19,12 +19,6 @@ description="Train a model with MLflow tracking." ) -parser.add_argument( - "--max_evals", - type=int, - default=3, - help="The max evaluations for the hyperparameter tuning with hyperopt", -) parser.add_argument( "--target", type=str, @@ -40,7 +34,7 @@ parser.add_argument( "--plot_folder", type=str, - default="plots/", + default="plots/baselines/", help="The Folder where the plots are saved", ) @@ -67,7 +61,7 @@ def main(args) -> pd.DataFrame: 1.0, ] models = [ - "baseline_random_ligand", + "baseline_random_selector", "baseline_most_often", ] @@ -105,7 +99,7 @@ def main(args) -> pd.DataFrame: y_labels, ) = data_loader.load_data() - if model_name == "baseline_random_ligand": + if model_name == "baseline_random_selector": multioutput_model = DummyClassifier(strategy="uniform") elif model_name == "baseline_most_often": multioutput_model = DummyClassifier( @@ -141,7 +135,6 @@ def main(args) -> pd.DataFrame: dataset_size, args.include_structural, model_name, - args.max_evals, ) return unified_metrics diff --git a/scripts/training/multi_targets.py b/scripts/training/multi_targets.py index c65e360..15d883f 100644 --- a/scripts/training/multi_targets.py +++ b/scripts/training/multi_targets.py @@ -12,7 +12,7 @@ from nmrcraft.models.model_configs import model_configs from nmrcraft.models.models import load_model from nmrcraft.training.hyperparameter_tune import HyperparameterTuner -from nmrcraft.utils.general import add_rows_metrics +from nmrcraft.utils.general import add_rows_metrics, str2bool # Setup MLflow mlflow.set_experiment("Final_results") @@ -36,8 +36,8 @@ ) parser.add_argument( "--include_structural", - type=bool, - default=False, + type=str2bool, + default="False", help="Handles if structural features will be included or only nmr tensors are used.", ) parser.add_argument( diff --git a/scripts/training/one_target.py b/scripts/training/one_target.py index 084bd09..ca9417f 100644 --- a/scripts/training/one_target.py +++ b/scripts/training/one_target.py @@ -12,7 +12,7 @@ from nmrcraft.models.model_configs import model_configs from nmrcraft.models.models import load_model from nmrcraft.training.hyperparameter_tune import HyperparameterTuner -from nmrcraft.utils.general import add_rows_metrics +from nmrcraft.utils.general import add_rows_metrics, str2bool # Setup MLflow mlflow.set_experiment("Final_Results") @@ -36,8 +36,8 @@ ) parser.add_argument( "--include_structural", - type=bool, - default=False, + type=str2bool, + default="False", help="Handles if structural features will be included or only nmr tensors are used.", ) parser.add_argument( @@ -161,7 +161,6 @@ def main(args) -> pd.DataFrame: # Add arguments args = parser.parse_args() args.target = [args.target] # FIXME - unified_metrics = main(args) # save all the results