diff --git a/models/classifier_builder.py b/models/classifier_builder.py index bb26a6a..b7f846f 100644 --- a/models/classifier_builder.py +++ b/models/classifier_builder.py @@ -1,8 +1,9 @@ import logging import numpy as np +import os from sklearn.linear_model import LogisticRegression from sklearn.metrics import accuracy_score, f1_score - +from sklearn.externals import joblib from .model_builder import ModelBuilder from .doc2vec_builder import doc2VecBuilder logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) @@ -24,8 +25,12 @@ def train_model(self, d2v, training_vectors, training_labels): logging.info('Training F1 score: {}'.format(f1_score(training_labels, training_predictions, average='weighted'))) - def save_model(self): - pass + def save_model(self, filename): + joblib.dump(self.model,"./classifiers/"+ filename) - def load_model(self): - pass \ No newline at end of file + def load_model(self,filename): + if (os.path.isfile('./classifiers/' + filename)): + loaded_model = joblib.load(filename) + self.model = loaded_model + else: + self.model = None \ No newline at end of file