-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
executable file
·68 lines (52 loc) · 1.64 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import logging
from pathlib import Path
# Configure the root logger before the project modules below are imported.
# logging.basicConfig() does nothing once the root logger already has
# handlers, so it runs first so the file handler is installed before any
# imported module can log.
# NOTE: the log path is relative to the current working directory, and the
# ``encoding`` argument requires Python 3.9+.
_LOG_PATH = Path("output") / "log_file.log"
# basicConfig opens the log file immediately, and FileHandler raises
# FileNotFoundError if the parent directory is missing (e.g. a fresh clone
# without ``output/``), so make sure the directory exists first.
_LOG_PATH.parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    filename=_LOG_PATH,
    encoding="utf-8",
    format="%(asctime)s | %(levelname)s | %(module)s | %(message)s",
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
logger.debug("=============== Started! ===============")
from post_processing import run, generate_average_result, compare_results
from cutoff_eval import cutoff_eval
from load_data import load_train_test
from Models.Baseline import Baseline
from Models.ModifiedBaseline import ModifiedBaseline
from Models.DeepBaseline import DeepBaseline
from Models.DepthWiseBaseline import DepthWiseBaseline
from Models.AcousticModel import AcousticModel
from Models.GenreModel import GenreModel
from Models.ResNet1D import ResNet1D
# One instance of each model architecture, keyed by a short lowercase name.
# Every model is constructed eagerly here, at import time, whether or not the
# enabled pipeline step below actually uses it.
model_dict = dict(
    baseline=Baseline(),
    modifiedbaseline=ModifiedBaseline(),
    deepbaseline=DeepBaseline(),
    depthwisebaseline=DepthWiseBaseline(),
    acoustic=AcousticModel(),
    genre=GenreModel(),
    resnet=ResNet1D(),
)
# The registry's keys, in registration order.
model_names = [*model_dict]
# Repetition count handed to run(...) in the training step below.
total_run = 10
# Pipeline steps, toggled by the flags below instead of by commenting lines
# in and out, so every step stays syntax-checked and consistent with
# ``model_dict``. With these defaults only load_train_test(use_long=False)
# runs.
RUN_TRAINING = False
ADD_LONG_AVERAGE = False
COMPARE_RESULTS = False
RUN_CUTOFF_EVAL = False
BUILD_SONG_ID_MAPPING = True

# Model keys passed to cutoff_eval() when RUN_CUTOFF_EVAL is enabled.
CUTOFF_EVAL_MODELS = ("baseline", "genre", "resnet")

## MAIN FUNCTION
if RUN_TRAINING:
    run(model_dict, total_run, use_long=True)

## ADD LONG AVERAGE
if ADD_LONG_AVERAGE:
    # One "<key>_long" name per registered model. Deriving the names from
    # model_dict keeps this list from drifting out of sync with the registry.
    for name in model_dict:
        generate_average_result(f"{name}_long")

## COMPARE RESULTS
if COMPARE_RESULTS:
    compare_results()

## CUTOFF EVAL
if RUN_CUTOFF_EVAL:
    for name in CUTOFF_EVAL_MODELS:
        cutoff_eval(name)

## SONG NAME ID MAPPING
if BUILD_SONG_ID_MAPPING:
    load_train_test(use_long=False)