Skip to content

Commit

Permalink
Refactoring of load model method
Browse files Browse the repository at this point in the history
  • Loading branch information
HardNorth committed Nov 21, 2023
1 parent 95c6810 commit 4caf23d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 38 deletions.
13 changes: 9 additions & 4 deletions app/machine_learning/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,22 @@ def __init__(self, folder: str, tags: str, *, object_saver: ObjectSaver = None,
else:
self.object_saver = ObjectSaver({CONFIG_KEY: 'filesystem', 'filesystemDefaultPath': folder})

def load_model(self, model_files: list[str]) -> list[Any]:
    """Load every listed model file from this model's folder.

    Each entry of ``model_files`` is joined onto ``self.folder`` and fetched
    through ``self.object_saver.get_project_object`` with ``using_json=False``.

    :param model_files: file names, relative to ``self.folder``
    :return: the loaded objects, in the same order as ``model_files``
    :raises ValueError: if the object saver returns ``None`` for any file
    """
    loaded = []
    for file_name in model_files:
        path = os.path.join(self.folder, file_name)
        obj = self.object_saver.get_project_object(path, using_json=False)
        if obj is None:
            # Fail on the first missing file so callers never receive a
            # partially populated result.
            raise ValueError(f'Unable to load model "{file_name}".')
        loaded.append(obj)
    return loaded

def get_model_info(self):
    """Return the model tags, prefixed with the model folder's base name.

    The folder name is taken from ``self.folder`` after removing leading and
    trailing path separators. When that leaves no name (empty folder, or a
    folder made only of separators) ``self.tags`` is returned unchanged.

    :return: ``[folder_name] + self.tags`` when a folder name exists,
        otherwise ``self.tags`` itself
    """
    # Strip both separator kinds in a single pass. Chaining
    # .strip("/").strip("\\") only removes them in that fixed order, so a
    # path ending in "/\" kept a trailing "/" and produced an empty base name.
    folder_name = os.path.basename(self.folder.strip("/\\")).strip()
    if not folder_name:
        return self.tags
    return [folder_name] + self.tags

@abstractmethod
def load_model(self):
    """Load the model; this base stub has no implementation.

    :raises NotImplementedError: always, when this stub itself is called
    """
    raise NotImplementedError('"load_model" method is not implemented!')

@abstractmethod
def save_model(self):
    """Persist the model; this base stub has no implementation.

    :raises NotImplementedError: always, when this stub itself is called
    """
    raise NotImplementedError('"save_model" method is not implemented!')
19 changes: 3 additions & 16 deletions app/machine_learning/models/custom_defect_type_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,11 @@


class CustomDefectTypeModel(DefectTypeModel):
project_id: int | str

def __init__(self, folder: str, app_config: dict[str, Any], project_id: int | str):
super().__init__(folder, 'custom defect type model')
self.project_id = project_id
self.object_saver = ObjectSaver(app_config)

def load_model(self):
self.count_vectorizer_models = self.object_saver.get_project_object(
os.path.join(self.folder, "count_vectorizer_models"), self.project_id, using_json=False)
assert len(self.count_vectorizer_models) > 0
self.models = self.object_saver.get_project_object(os.path.join(self.folder, "models"), self.project_id,
using_json=False)
assert len(self.models) > 0
super().__init__(folder, 'custom defect type model', object_saver=ObjectSaver(app_config, project_id))

def save_model(self):
self.object_saver.put_project_object(self.count_vectorizer_models,
os.path.join(self.folder, "count_vectorizer_models"), self.project_id,
using_json=False)
self.object_saver.put_project_object(self.models, os.path.join(self.folder, "models"), self.project_id,
using_json=False)
os.path.join(self.folder, "count_vectorizer_models"), using_json=False)
self.object_saver.put_project_object(self.models, os.path.join(self.folder, "models"), using_json=False)
21 changes: 3 additions & 18 deletions app/machine_learning/models/defect_type_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,16 @@

from app.commons.object_saving.object_saver import ObjectSaver
from app.machine_learning.models import MlModel
from app.utils import text_processing, utils
from app.utils import text_processing


class DefectTypeModel(MlModel):

def __init__(self, folder: str, tags: str = 'global defect type model', object_saver: ObjectSaver = None,
app_config: dict[str, Any] = None) -> None:
super().__init__(folder, tags, object_saver=object_saver, app_config=app_config)
self.count_vectorizer_models = {}
self.models = {}
self.load_model()

def load_model(self):
if not utils.validate_folder(self.folder):
raise ValueError(f'Invalid model folder path: {self.folder}')

count_vectorizer_models_file = os.path.join(self.folder, "count_vectorizer_models.pickle")
models_file = os.path.join(self.folder, "models.pickle")
if not utils.validate_file(count_vectorizer_models_file) or not utils.validate_file(models_file):
raise ValueError(f'Model folder path does not contains necessary files: {self.folder}')

with open(count_vectorizer_models_file, "rb") as f:
self.count_vectorizer_models = pickle.load(f)
with open(models_file, "rb") as f:
self.models = pickle.load(f)
self.count_vectorizer_models, self.models = self.load_model(
['count_vectorizer_models.pickle', 'models.pickle'])

def save_model(self):
os.makedirs(self.folder, exist_ok=True)
Expand Down

0 comments on commit 4caf23d

Please sign in to comment.