From 2ff9f9a1f6a80003c70bdd55a130d001b764fd51 Mon Sep 17 00:00:00 2001 From: LukasFehring <72503857+LukasFehring@users.noreply.github.com> Date: Mon, 21 Nov 2022 14:56:05 +0100 Subject: [PATCH] Modify functionality to reset experiments (#77) * Add Enum `ExperimentStatus` * Modify `experimenter.reset_experiments()` to be able to get * single `ExperimentStatus` * list of `ExperimentStatus` * `"all"` to reset all `ExperimentStatus` * Add method `experimenter.delete_table()` * Add method `delete_table()` * Increase readability of SQL statements --- py_experimenter/database_connector.py | 69 ++++++++--- py_experimenter/experiment_status.py | 9 ++ py_experimenter/experimenter.py | 35 +++--- test/test_database_connector.py | 111 +++++++++++++++--- .../test_run_mysql_experiment.py | 14 +-- .../test_run_sqlite_experiment.py | 16 +-- 6 files changed, 180 insertions(+), 74 deletions(-) create mode 100644 py_experimenter/experiment_status.py diff --git a/py_experimenter/database_connector.py b/py_experimenter/database_connector.py index ee0c1c7b..09746456 100644 --- a/py_experimenter/database_connector.py +++ b/py_experimenter/database_connector.py @@ -1,12 +1,13 @@ import abc import logging from datetime import datetime -from typing import List, Tuple +from typing import List, Optional, Tuple import pandas as pd from py_experimenter import utils from py_experimenter.exceptions import DatabaseConnectionError, EmptyFillDatabaseCallError, NoExperimentsLeftException, TableHasWrongStructureError +from py_experimenter.experiment_status import ExperimentStatus class DatabaseConnector(abc.ABC): @@ -149,7 +150,7 @@ def fill_table(self, parameters=None, fixed_parameter_combinations=None) -> None if self._check_combination_in_existing_rows(combination, existing_rows, keyfield_names): continue values = list(combination.values()) - values.append("created") + values.append(ExperimentStatus.CREATED.value) values.append("%s" % time.strftime("%m/%d/%Y, %H:%M:%S")) self._write_to_database(column_names.split(', '), values) @@ -175,19 +176,17 @@ def get_experiment_configuration(self, random_order: bool): return experiment_id, dict(zip([i[0] for i in description], *values)) - def _execute_queries(self, connection, cursor, random_order) ->Tuple[int, List, List]: + def _execute_queries(self, connection, cursor, random_order) -> Tuple[int, List, List]: if random_order: order_by = self.__class__.random_order_string() else: order_by = "id" - select_experiment = f"SELECT id FROM {self.table_name} WHERE status = 'created' ORDER BY {order_by} LIMIT 1;" - alter_experiment = "UPDATE {} SET status = 'running' WHERE id = {};" - select_keyfields = "SELECT {} FROM {} WHERE id = {};" - self.execute(cursor, select_experiment) + + self.execute(cursor, f"SELECT id FROM {self.table_name} WHERE status = 'created' ORDER BY {order_by} LIMIT 1;") experiment_id = self.fetchall(cursor)[0][0] - self.execute(cursor, alter_experiment.format(self.table_name, experiment_id)) - self.execute(cursor, select_keyfields.format(','.join(utils.get_keyfield_names( - self.database_credential_file_path)), self.table_name, experiment_id)) + self.execute(cursor, f"UPDATE {self.table_name} SET status = '{ExperimentStatus.RUNNING.value}' WHERE id = {experiment_id};") + keyfields = ','.join(utils.get_keyfield_names(self.database_credential_file_path)) + self.execute(cursor, f"SELECT {keyfields} FROM {self.table_name} WHERE id = {experiment_id};") values = self.fetchall(cursor) self.commit(connection) description = cursor.description @@ -196,7 +195,7 @@ def _execute_queries(self, connection, cursor, random_order) ->Tuple[int, List, @abc.abstractstaticmethod def random_order_string(): pass - + @abc.abstractmethod def _pull_open_experiment(self, random_order) -> Tuple[int, List, List]: pass @@ -252,7 +251,28 @@ def not_executed_yet(self, where) -> bool: connection.close() return not_executed - def delete_experiments_with_status(self, status): + def reset_experiments(self, *states: str) -> None: + def get_dict_for_keyfields_and_rows(keyfields: List[str], rows: List[List[str]]) -> List[dict]: + return [{key: value for key, value in zip(keyfields, row)} for row in rows] + + for state in states: + keyfields, rows = self._pop_experiments_with_status(state) + rows = get_dict_for_keyfields_and_rows(keyfields, rows) + if rows: + self.fill_table(fixed_parameter_combinations=rows) + logging.info(f"{len(rows)} experiments with status {' '.join(list(states))} were reset") + + def _pop_experiments_with_status(self, status: Optional[str] = None) -> Tuple[List[str], List[List]]: + if status == ExperimentStatus.ALL.value: + condition = None + else: + condition = f"WHERE status = '{status}'" + + column_names, entries = self._get_experiments_with_condition(condition) + self._delete_experiments_with_condition(condition) + return column_names, entries + + def _get_experiments_with_condition(self, condition: Optional[str] = None) -> Tuple[List[str], List[List]]: def _get_keyfields_from_columns(column_names, entries): df = pd.DataFrame(entries, columns=column_names) keyfields = utils.get_keyfield_names(self.database_credential_file_path) @@ -261,21 +281,34 @@ def _get_keyfields_from_columns(column_names, entries): connection = self.connect() cursor = self.cursor(connection) - column_names = self.get_structure_from_table(cursor) - self.execute(cursor, f"SELECT * FROM {self.table_name} WHERE status='{status}'") - entries = cursor.fetchall() + query_condition= condition or '' + self.execute(cursor, f"SELECT * FROM {self.table_name} {query_condition}") + entries = self.fetchall(cursor) + column_names = self.get_structure_from_table(cursor) column_names, entries = _get_keyfields_from_columns(column_names, entries) - self.execute(cursor, f"DELETE FROM {self.table_name} WHERE status='{status}'") - self.commit(connection) - self.close_connection(connection) return column_names, entries + def _delete_experiments_with_condition(self, condition: Optional[str] = None) -> None: + connection = self.connect() + cursor = self.cursor(connection) + + query_condition = condition or '' + self.execute(cursor, f'DELETE FROM {self.table_name} {query_condition}') + self.commit(connection) + self.close_connection(connection) + @abc.abstractmethod def get_structure_from_table(self, cursor): pass + def delete_table(self) -> None: + connection = self.connect() + cursor = self.cursor(connection) + self.execute(cursor, f'DROP TABLE IF EXISTS {self.table_name}') + self.commit(connection) + def get_table(self) -> pd.DataFrame: connection = self.connect() query = f"SELECT * FROM {self.table_name}" diff --git a/py_experimenter/experiment_status.py b/py_experimenter/experiment_status.py new file mode 100644 index 00000000..339e6cb6 --- /dev/null +++ b/py_experimenter/experiment_status.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class ExperimentStatus(Enum): + CREATED = 'created' + RUNNING = 'running' + DONE = 'done' + ERROR = 'error' + ALL = 'all' diff --git a/py_experimenter/experimenter.py b/py_experimenter/experimenter.py index 830a261f..059e5317 100644 --- a/py_experimenter/experimenter.py +++ b/py_experimenter/experimenter.py @@ -12,6 +12,7 @@ from py_experimenter.database_connector_lite import DatabaseConnectorLITE from py_experimenter.database_connector_mysql import DatabaseConnectorMYSQL from py_experimenter.exceptions import InvalidConfigError, InvalidValuesInConfiguration, NoExperimentsLeftException +from py_experimenter.experiment_status import ExperimentStatus from py_experimenter.result_processor import ResultProcessor @@ -404,32 +405,30 @@ def _execution_wrapper(self, error_msg = traceback.format_exc() logging.error(error_msg) result_processor._write_error(error_msg) - result_processor._change_status('error') + result_processor._change_status(ExperimentStatus.ERROR.value) else: - result_processor._change_status('done') + result_processor._change_status(ExperimentStatus.DONE.value) - def reset_experiments(self, status) -> None: + def reset_experiments(self, *states: str) -> None: """ Deletes the experiments of the database table having the given `status`. Afterwards, all rows that have been deleted from the database table are added to the table again featuring `created` as a status. Experiments - to reset can be selected based on the following status: + to reset can be selected based on the following status definition: - * `created` when the experiment is added to the database table, execution has not started. - * `running` when the execution of the experiment has been started. - * `error` if something went wrong during the execution, i.e., an exception is raised - * `done` if the execution finished successfully. - - :param status: The status of experiments that should be reset. - :type status: str + :param status: The status of experiments that should be reset. Either `created`, `running`, `error`, `done`, or `all`. + Note that `states` is a variable length argument, so multiple states can be given as a list. + :type status: str """ - def get_dict_for_keyfields_and_rows(keyfields: List[str], rows: List[List[str]]) -> List[dict]: - return [{key: value for key, value in zip(keyfields, row)} for row in rows] + if not states: + logging.warning('No states given to reset experiments. No experiments are reset.') + else: + self.dbconnector.reset_experiments(*states) - keyfields, rows = self.dbconnector.delete_experiments_with_status(status) - rows = get_dict_for_keyfields_and_rows(keyfields, rows) - if rows: - self.fill_table_with_rows(rows) - logging.info(f"{len(rows)} experiments with status {status} were reset") + def delete_table(self) -> None: + """ + Drops the table defined in the configuration file. + """ + self.dbconnector.delete_table() def get_table(self) -> pd.DataFrame: """ diff --git a/test/test_database_connector.py b/test/test_database_connector.py index b35167bb..1b49066b 100644 --- a/test/test_database_connector.py +++ b/test/test_database_connector.py @@ -1,4 +1,3 @@ - import datetime import os @@ -9,6 +8,7 @@ from py_experimenter import database_connector, database_connector_mysql from py_experimenter.database_connector import DatabaseConnector from py_experimenter.database_connector_mysql import DatabaseConnectorMYSQL +from py_experimenter.experiment_status import ExperimentStatus from py_experimenter.utils import load_config @@ -72,26 +72,26 @@ def test_create_table_if_not_exists(create_database_if_not_existing_mock, test_c [], ['value,exponent,status,creation_date'], [ - [1, 3, 'created'], - [1, 4, 'created'], - [2, 3, 'created'], - [2, 4, 'created'] + [1, 3, ExperimentStatus.CREATED.value], + [1, 4, ExperimentStatus.CREATED.value], + [2, 3, ExperimentStatus.CREATED.value], + [2, 4, ExperimentStatus.CREATED.value] ]), (os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg'), {}, [{'value': 1, 'exponent': 3}, {'value': 1, 'exponent': 4}], ['value,exponent,status,creation_date'], [ - [1, 3, 'created'], - [1, 4, 'created'], + [1, 3, ExperimentStatus.CREATED.value], + [1, 4, ExperimentStatus.CREATED.value], ]), (os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file_3_parameters.cfg'), {'value': [1, 2], }, [{'exponent': 3, 'other_value': 5}], ['value,exponent,other_value,status,creation_date'], [ - [1, 3, 5, 'created'], - [2, 3, 5, 'created'], + [1, 3, 5, ExperimentStatus.CREATED.value], + [2, 3, 5, ExperimentStatus.CREATED.value], ] ), (os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file_3_parameters.cfg'), @@ -99,10 +99,10 @@ def test_create_table_if_not_exists(create_database_if_not_existing_mock, test_c [{'other_value': 5}], ['value,exponent,other_value,status,creation_date'], [ - [1, 3, 5, 'created'], - [1, 4, 5, 'created'], - [2, 3, 5, 'created'], - [2, 4, 5, 'created'], + [1, 3, 5, ExperimentStatus.CREATED.value], + [1, 4, 5, ExperimentStatus.CREATED.value], + [2, 3, 5, ExperimentStatus.CREATED.value], + [2, 4, 5, ExperimentStatus.CREATED.value], ] ), ] @@ -128,7 +128,7 @@ def test_fill_table( experiment_configuration = load_config(experiment_configuration_file_path) database_connector = DatabaseConnectorMYSQL( - experiment_configuration, + experiment_configuration, database_credential_file_path=os.path.join('test', 'test_config_files', 'load_config_test_file', 'mysql_fake_credentials.cfg')) database_connector.fill_table(parameters, fixed_parameter_combination) args = write_to_database_mock.call_args_list @@ -141,3 +141,86 @@ def test_fill_table( assert datetime_from_string_argument.day == datetime.datetime.now().day assert datetime_from_string_argument.hour == datetime.datetime.now().hour assert datetime_from_string_argument.minute - datetime.datetime.now().minute <= 2 + + +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_test_connection') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_create_database_if_not_existing') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'connect') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'close_connection') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'cursor') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'execute') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'commit') +def test_delete_experiments_with_condition(commit_mock, execute_mock, cursor_mock, close_conenction_mock, connect_mock, create_database_if_not_existing_mock, _test_connection_mock): + create_database_if_not_existing_mock.return_value = None + _test_connection_mock.return_value = None + connect_mock.return_value = None + close_conenction_mock.return_value = None + execute_mock.return_value = None + cursor_mock.return_value = None + commit_mock.return_value = None + + experiment_configuration_file_path = load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg')) + database_connector = DatabaseConnectorMYSQL( + experiment_configuration_file_path, + database_credential_file_path=os.path.join( + 'test', 'test_config_files', 'load_config_test_file', 'mysql_fake_credentials.cfg') + ) + + database_connector._delete_experiments_with_condition(f'WHERE status = "{ExperimentStatus.CREATED.value}"') + + args = execute_mock.call_args_list + assert len(args) == 1 + assert args[0][0][1] == f'DELETE FROM test_table WHERE status = "{ExperimentStatus.CREATED.value}"' + + +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_test_connection') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_create_database_if_not_existing') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'connect') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'cursor') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'execute') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'commit') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'fetchall') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'get_structure_from_table') +def test_get_experiments_with_condition(get_structture_from_table_mock, fetchall_mock, commit_mock, execute_mock, cursor_mock, connect_mock, create_database_if_not_existing_mock, _test_connection_mock): + create_database_if_not_existing_mock.return_value = None + _test_connection_mock.return_value = None + connect_mock.return_value = None + execute_mock.return_value = None + cursor_mock.return_value = None + commit_mock.return_value = None + fetchall_mock.return_value = [(1, 2,), ] + get_structture_from_table_mock.return_value = ['value', 'exponent'] + experiment_configuration_file_path = load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg')) + database_connector = DatabaseConnectorMYSQL( + experiment_configuration_file_path, + database_credential_file_path=os.path.join( + 'test', 'test_config_files', 'load_config_test_file', 'mysql_fake_credentials.cfg') + ) + database_connector._get_experiments_with_condition(f'WHERE status = "{ExperimentStatus.CREATED.value}"') + + assert execute_mock.call_args_list[0][0][1] == f'SELECT * FROM test_table WHERE status = "{ExperimentStatus.CREATED.value}"' + + +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_test_connection') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_create_database_if_not_existing') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'connect') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'cursor') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'execute') +@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'commit') +def test_delete_table(commit_mock, execute_mock, cursor_mock, connect_mock, create_database_if_not_existing_mock, _test_connection_mock): + create_database_if_not_existing_mock.return_value = None + _test_connection_mock.return_value = None + connect_mock.return_value = None + execute_mock.return_value = None + cursor_mock.return_value = None + commit_mock.return_value = None + experiment_configuration_file_path = load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg')) + database_connector = DatabaseConnectorMYSQL( + experiment_configuration_file_path, + database_credential_file_path=os.path.join( + 'test', 'test_config_files', 'load_config_test_file', 'mysql_fake_credentials.cfg') + ) + database_connector.delete_table() + + assert execute_mock.call_count == 1 + assert execute_mock.call_args[0][1] == 'DROP TABLE IF EXISTS test_table' diff --git a/test/test_run_experiments/test_run_mysql_experiment.py b/test/test_run_experiments/test_run_mysql_experiment.py index b4155a36..fa9d1a5a 100644 --- a/test/test_run_experiments/test_run_mysql_experiment.py +++ b/test/test_run_experiments/test_run_mysql_experiment.py @@ -31,24 +31,12 @@ def check_done_entries(experimenter, amount_of_entries): experimenter.dbconnector.close_connection(connection) -def delete_existing_table(experimenter): - connection = experimenter.dbconnector.connect() - cursor = experimenter.dbconnector.cursor(connection) - try: - cursor.execute(f"DROP TABLE {experimenter.dbconnector.table_name}") - experimenter.dbconnector.commit(connection) - experimenter.dbconnector.close_connection(connection) - except ProgrammingError as e: - experimenter.dbconnector.close_connection(connection) - logging.warning(e) - - def test_run_all_mqsql_experiments(): experiment_configuration_file_path = os.path.join('test', 'test_run_experiments', 'test_run_mysql_experiment_config.cfg') logging.basicConfig(level=logging.DEBUG) experimenter = PyExperimenter(experiment_configuration_file_path=experiment_configuration_file_path) try: - delete_existing_table(experimenter) + experimenter.delete_table() except ProgrammingError as e: logging.warning(e) experimenter.fill_table_from_config() diff --git a/test/test_run_experiments/test_run_sqlite_experiment.py b/test/test_run_experiments/test_run_sqlite_experiment.py index 71c5274e..b5b5317a 100644 --- a/test/test_run_experiments/test_run_sqlite_experiment.py +++ b/test/test_run_experiments/test_run_sqlite_experiment.py @@ -31,22 +31,16 @@ def check_done_entries(experimenter, amount_of_entries): experimenter.dbconnector.close_connection(connection) -def delete_existing_table(experimenter): - connection= experimenter.dbconnector.connect() - cursor= experimenter.dbconnector.cursor(connection) - try: - cursor.execute("DROP TABLE IF EXISTS test_table") - experimenter.dbconnector.commit(connection) - experimenter.dbconnector.close_connection(connection) - except ProgrammingError: - experimenter.dbconnector.close_connection(connection) - logging.warning("Table test_table does not exist") + def test_run_all_sqlite_experiments(): logging.basicConfig(level=logging.DEBUG) experimenter= PyExperimenter(experiment_configuration_file_path=os.path.join('test', 'test_run_experiments', 'test_run_sqlite_experiment_config.cfg')) - delete_existing_table(experimenter) + try: + experimenter.delete_table() + except Exception: + pass experimenter.fill_table_from_config() experimenter.execute(own_function, 1)