Skip to content

Commit

Permalink
Modify functionality to reset experiments (#77)
Browse files Browse the repository at this point in the history
* Add Enum `ExperimentStatus` 
* Modify `experimenter.reset_experiments()` to be able to get
  * single `ExperimentStatus` 
  * list of `ExperimentStatus` 
  * `"all"` to reset all `ExperimentStatus` 
* Add method `experimenter.delete_table()` 
* Add method `delete_table()` 
* Increase readability of SQL statements
  • Loading branch information
LukasFehring authored Nov 21, 2022
1 parent 37d3ed8 commit 2ff9f9a
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 74 deletions.
69 changes: 51 additions & 18 deletions py_experimenter/database_connector.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import abc
import logging
from datetime import datetime
from typing import List, Tuple
from typing import List, Optional, Tuple

import pandas as pd

from py_experimenter import utils
from py_experimenter.exceptions import DatabaseConnectionError, EmptyFillDatabaseCallError, NoExperimentsLeftException, TableHasWrongStructureError
from py_experimenter.experiment_status import ExperimentStatus


class DatabaseConnector(abc.ABC):
Expand Down Expand Up @@ -149,7 +150,7 @@ def fill_table(self, parameters=None, fixed_parameter_combinations=None) -> None
if self._check_combination_in_existing_rows(combination, existing_rows, keyfield_names):
continue
values = list(combination.values())
values.append("created")
values.append(ExperimentStatus.CREATED.value)
values.append("%s" % time.strftime("%m/%d/%Y, %H:%M:%S"))

self._write_to_database(column_names.split(', '), values)
Expand All @@ -175,19 +176,17 @@ def get_experiment_configuration(self, random_order: bool):

return experiment_id, dict(zip([i[0] for i in description], *values))

def _execute_queries(self, connection, cursor, random_order) ->Tuple[int, List, List]:
def _execute_queries(self, connection, cursor, random_order) -> Tuple[int, List, List]:
if random_order:
order_by = self.__class__.random_order_string()
else:
order_by = "id"
select_experiment = f"SELECT id FROM {self.table_name} WHERE status = 'created' ORDER BY {order_by} LIMIT 1;"
alter_experiment = "UPDATE {} SET status = 'running' WHERE id = {};"
select_keyfields = "SELECT {} FROM {} WHERE id = {};"
self.execute(cursor, select_experiment)

self.execute(cursor, f"SELECT id FROM {self.table_name} WHERE status = 'created' ORDER BY {order_by} LIMIT 1;")
experiment_id = self.fetchall(cursor)[0][0]
self.execute(cursor, alter_experiment.format(self.table_name, experiment_id))
self.execute(cursor, select_keyfields.format(','.join(utils.get_keyfield_names(
self.database_credential_file_path)), self.table_name, experiment_id))
self.execute(cursor, f"UPDATE {self.table_name} SET status = '{ExperimentStatus.RUNNING.value}' WHERE id = {experiment_id};")
keyfields = ','.join(utils.get_keyfield_names(self.database_credential_file_path))
self.execute(cursor, f"SELECT {keyfields} FROM {self.table_name} WHERE id = {experiment_id};")
values = self.fetchall(cursor)
self.commit(connection)
description = cursor.description
Expand All @@ -196,7 +195,7 @@ def _execute_queries(self, connection, cursor, random_order) ->Tuple[int, List,
@abc.abstractstaticmethod
def random_order_string():
pass

@abc.abstractmethod
def _pull_open_experiment(self, random_order) -> Tuple[int, List, List]:
pass
Expand Down Expand Up @@ -252,7 +251,28 @@ def not_executed_yet(self, where) -> bool:
connection.close()
return not_executed

def delete_experiments_with_status(self, status):
def reset_experiments(self, *states: str) -> None:
def get_dict_for_keyfields_and_rows(keyfields: List[str], rows: List[List[str]]) -> List[dict]:
return [{key: value for key, value in zip(keyfields, row)} for row in rows]

for state in states:
keyfields, rows = self._pop_experiments_with_status(state)
rows = get_dict_for_keyfields_and_rows(keyfields, rows)
if rows:
self.fill_table(fixed_parameter_combinations=rows)
logging.info(f"{len(rows)} experiments with status {' '.join(list(states))} were reset")

def _pop_experiments_with_status(self, status: Optional[str] = None) -> Tuple[List[str], List[List]]:
if status == ExperimentStatus.ALL.value:
condition = None
else:
condition = f"WHERE status = '{status}'"

column_names, entries = self._get_experiments_with_condition(condition)
self._delete_experiments_with_condition(condition)
return column_names, entries

def _get_experiments_with_condition(self, condition: Optional[str] = None) -> Tuple[List[str], List[List]]:
def _get_keyfields_from_columns(column_names, entries):
df = pd.DataFrame(entries, columns=column_names)
keyfields = utils.get_keyfield_names(self.database_credential_file_path)
Expand All @@ -261,21 +281,34 @@ def _get_keyfields_from_columns(column_names, entries):

connection = self.connect()
cursor = self.cursor(connection)
column_names = self.get_structure_from_table(cursor)

self.execute(cursor, f"SELECT * FROM {self.table_name} WHERE status='{status}'")
entries = cursor.fetchall()
query_condition= condition or ''
self.execute(cursor, f"SELECT * FROM {self.table_name} {query_condition}")
entries = self.fetchall(cursor)
column_names = self.get_structure_from_table(cursor)
column_names, entries = _get_keyfields_from_columns(column_names, entries)
self.execute(cursor, f"DELETE FROM {self.table_name} WHERE status='{status}'")
self.commit(connection)
self.close_connection(connection)

return column_names, entries

def _delete_experiments_with_condition(self, condition: Optional[str] = None) -> None:
connection = self.connect()
cursor = self.cursor(connection)

query_condition = condition or ''
self.execute(cursor, f'DELETE FROM {self.table_name} {query_condition}')
self.commit(connection)
self.close_connection(connection)

@abc.abstractmethod
def get_structure_from_table(self, cursor):
pass

def delete_table(self) -> None:
connection = self.connect()
cursor = self.cursor(connection)
self.execute(cursor, f'DROP TABLE IF EXISTS {self.table_name}')
self.commit(connection)

def get_table(self) -> pd.DataFrame:
connection = self.connect()
query = f"SELECT * FROM {self.table_name}"
Expand Down
9 changes: 9 additions & 0 deletions py_experimenter/experiment_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from enum import Enum


class ExperimentStatus(Enum):
CREATED = 'created'
RUNNING = 'running'
DONE = 'done'
ERROR = 'error'
ALL = 'all'
35 changes: 17 additions & 18 deletions py_experimenter/experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from py_experimenter.database_connector_lite import DatabaseConnectorLITE
from py_experimenter.database_connector_mysql import DatabaseConnectorMYSQL
from py_experimenter.exceptions import InvalidConfigError, InvalidValuesInConfiguration, NoExperimentsLeftException
from py_experimenter.experiment_status import ExperimentStatus
from py_experimenter.result_processor import ResultProcessor


Expand Down Expand Up @@ -404,32 +405,30 @@ def _execution_wrapper(self,
error_msg = traceback.format_exc()
logging.error(error_msg)
result_processor._write_error(error_msg)
result_processor._change_status('error')
result_processor._change_status(ExperimentStatus.ERROR.value)
else:
result_processor._change_status('done')
result_processor._change_status(ExperimentStatus.DONE.value)

def reset_experiments(self, status) -> None:
def reset_experiments(self, *states: str) -> None:
"""
Deletes the experiments of the database table having the given `status`. Afterwards, all rows that have been
deleted from the database table are added to the table again featuring `created` as a status. Experiments
to reset can be selected based on the following status:
to reset can be selected based on the following status definition:
* `created` when the experiment is added to the database table, execution has not started.
* `running` when the execution of the experiment has been started.
* `error` if something went wrong during the execution, i.e., an exception is raised
* `done` if the execution finished successfully.
:param status: The status of experiments that should be reset.
:type status: str
:param status: The status of experiments that should be reset. Either `created`, `running`, `error`, `done`, or `all`.
Note that `states` is a variable length argument, so multiple states can be given as a list.
:type status: str
"""
def get_dict_for_keyfields_and_rows(keyfields: List[str], rows: List[List[str]]) -> List[dict]:
return [{key: value for key, value in zip(keyfields, row)} for row in rows]
if not states:
logging.warning('No states given to reset experiments. No experiments are reset.')
else:
self.dbconnector.reset_experiments(*states)

keyfields, rows = self.dbconnector.delete_experiments_with_status(status)
rows = get_dict_for_keyfields_and_rows(keyfields, rows)
if rows:
self.fill_table_with_rows(rows)
logging.info(f"{len(rows)} experiments with status {status} were reset")
def delete_table(self) -> None:
"""
Drops the table defined in the configuration file.
"""
self.dbconnector.delete_table()

def get_table(self) -> pd.DataFrame:
"""
Expand Down
111 changes: 97 additions & 14 deletions test/test_database_connector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import datetime
import os

Expand All @@ -9,6 +8,7 @@
from py_experimenter import database_connector, database_connector_mysql
from py_experimenter.database_connector import DatabaseConnector
from py_experimenter.database_connector_mysql import DatabaseConnectorMYSQL
from py_experimenter.experiment_status import ExperimentStatus
from py_experimenter.utils import load_config


Expand Down Expand Up @@ -72,37 +72,37 @@ def test_create_table_if_not_exists(create_database_if_not_existing_mock, test_c
[],
['value,exponent,status,creation_date'],
[
[1, 3, 'created'],
[1, 4, 'created'],
[2, 3, 'created'],
[2, 4, 'created']
[1, 3, ExperimentStatus.CREATED.value],
[1, 4, ExperimentStatus.CREATED.value],
[2, 3, ExperimentStatus.CREATED.value],
[2, 4, ExperimentStatus.CREATED.value]
]),
(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg'),
{},
[{'value': 1, 'exponent': 3}, {'value': 1, 'exponent': 4}],
['value,exponent,status,creation_date'],
[
[1, 3, 'created'],
[1, 4, 'created'],
[1, 3, ExperimentStatus.CREATED.value],
[1, 4, ExperimentStatus.CREATED.value],
]),
(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file_3_parameters.cfg'),
{'value': [1, 2], },
[{'exponent': 3, 'other_value': 5}],
['value,exponent,other_value,status,creation_date'],
[
[1, 3, 5, 'created'],
[2, 3, 5, 'created'],
[1, 3, 5, ExperimentStatus.CREATED.value],
[2, 3, 5, ExperimentStatus.CREATED.value],
]
),
(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file_3_parameters.cfg'),
{'value': [1, 2], 'exponent': [3, 4], },
[{'other_value': 5}],
['value,exponent,other_value,status,creation_date'],
[
[1, 3, 5, 'created'],
[1, 4, 5, 'created'],
[2, 3, 5, 'created'],
[2, 4, 5, 'created'],
[1, 3, 5, ExperimentStatus.CREATED.value],
[1, 4, 5, ExperimentStatus.CREATED.value],
[2, 3, 5, ExperimentStatus.CREATED.value],
[2, 4, 5, ExperimentStatus.CREATED.value],
]
),
]
Expand All @@ -128,7 +128,7 @@ def test_fill_table(

experiment_configuration = load_config(experiment_configuration_file_path)
database_connector = DatabaseConnectorMYSQL(
experiment_configuration,
experiment_configuration,
database_credential_file_path=os.path.join('test', 'test_config_files', 'load_config_test_file', 'mysql_fake_credentials.cfg'))
database_connector.fill_table(parameters, fixed_parameter_combination)
args = write_to_database_mock.call_args_list
Expand All @@ -141,3 +141,86 @@ def test_fill_table(
assert datetime_from_string_argument.day == datetime.datetime.now().day
assert datetime_from_string_argument.hour == datetime.datetime.now().hour
assert datetime_from_string_argument.minute - datetime.datetime.now().minute <= 2


@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_test_connection')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_create_database_if_not_existing')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'connect')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'close_connection')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'cursor')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'execute')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'commit')
def test_delete_experiments_with_condition(commit_mock, execute_mock, cursor_mock, close_conenction_mock, connect_mock, create_database_if_not_existing_mock, _test_connection_mock):
create_database_if_not_existing_mock.return_value = None
_test_connection_mock.return_value = None
connect_mock.return_value = None
close_conenction_mock.return_value = None
execute_mock.return_value = None
cursor_mock.return_value = None
commit_mock.return_value = None

experiment_configuration_file_path = load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg'))
database_connector = DatabaseConnectorMYSQL(
experiment_configuration_file_path,
database_credential_file_path=os.path.join(
'test', 'test_config_files', 'load_config_test_file', 'mysql_fake_credentials.cfg')
)

database_connector._delete_experiments_with_condition(f'WHERE status = "{ExperimentStatus.CREATED.value}"')

args = execute_mock.call_args_list
assert len(args) == 1
assert args[0][0][1] == f'DELETE FROM test_table WHERE status = "{ExperimentStatus.CREATED.value}"'


@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_test_connection')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_create_database_if_not_existing')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'connect')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'cursor')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'execute')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'commit')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'fetchall')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'get_structure_from_table')
def test_get_experiments_with_condition(get_structture_from_table_mock, fetchall_mock, commit_mock, execute_mock, cursor_mock, connect_mock, create_database_if_not_existing_mock, _test_connection_mock):
create_database_if_not_existing_mock.return_value = None
_test_connection_mock.return_value = None
connect_mock.return_value = None
execute_mock.return_value = None
cursor_mock.return_value = None
commit_mock.return_value = None
fetchall_mock.return_value = [(1, 2,), ]
get_structture_from_table_mock.return_value = ['value', 'exponent']
experiment_configuration_file_path = load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg'))
database_connector = DatabaseConnectorMYSQL(
experiment_configuration_file_path,
database_credential_file_path=os.path.join(
'test', 'test_config_files', 'load_config_test_file', 'mysql_fake_credentials.cfg')
)
database_connector._get_experiments_with_condition(f'WHERE status = "{ExperimentStatus.CREATED.value}"')

assert execute_mock.call_args_list[0][0][1] == f'SELECT * FROM test_table WHERE status = "{ExperimentStatus.CREATED.value}"'


@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_test_connection')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, '_create_database_if_not_existing')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'connect')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'cursor')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'execute')
@patch.object(database_connector_mysql.DatabaseConnectorMYSQL, 'commit')
def test_delete_table(commit_mock, execute_mock, cursor_mock, connect_mock, create_database_if_not_existing_mock, _test_connection_mock):
create_database_if_not_existing_mock.return_value = None
_test_connection_mock.return_value = None
connect_mock.return_value = None
execute_mock.return_value = None
cursor_mock.return_value = None
commit_mock.return_value = None
experiment_configuration_file_path = load_config(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg'))
database_connector = DatabaseConnectorMYSQL(
experiment_configuration_file_path,
database_credential_file_path=os.path.join(
'test', 'test_config_files', 'load_config_test_file', 'mysql_fake_credentials.cfg')
)
database_connector.delete_table()

assert execute_mock.call_count == 1
assert execute_mock.call_args[0][1] == 'DROP TABLE IF EXISTS test_table'
14 changes: 1 addition & 13 deletions test/test_run_experiments/test_run_mysql_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,12 @@ def check_done_entries(experimenter, amount_of_entries):
experimenter.dbconnector.close_connection(connection)


def delete_existing_table(experimenter):
connection = experimenter.dbconnector.connect()
cursor = experimenter.dbconnector.cursor(connection)
try:
cursor.execute(f"DROP TABLE {experimenter.dbconnector.table_name}")
experimenter.dbconnector.commit(connection)
experimenter.dbconnector.close_connection(connection)
except ProgrammingError as e:
experimenter.dbconnector.close_connection(connection)
logging.warning(e)


def test_run_all_mqsql_experiments():
experiment_configuration_file_path = os.path.join('test', 'test_run_experiments', 'test_run_mysql_experiment_config.cfg')
logging.basicConfig(level=logging.DEBUG)
experimenter = PyExperimenter(experiment_configuration_file_path=experiment_configuration_file_path)
try:
delete_existing_table(experimenter)
experimenter.delete_table()
except ProgrammingError as e:
logging.warning(e)
experimenter.fill_table_from_config()
Expand Down
Loading

0 comments on commit 2ff9f9a

Please sign in to comment.