From 831f96c4208a0678f9baf34a25934df59b12c0da Mon Sep 17 00:00:00 2001 From: Lukas Fehring Date: Mon, 8 Jan 2024 13:11:48 +0100 Subject: [PATCH] Add n_jobs parameter to execute --- py_experimenter/experimenter.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/py_experimenter/experimenter.py b/py_experimenter/experimenter.py index 4dd294fe..1dba665d 100644 --- a/py_experimenter/experimenter.py +++ b/py_experimenter/experimenter.py @@ -2,7 +2,7 @@ import os import socket import traceback -from typing import Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import pandas as pd from codecarbon import EmissionsTracker, OfflineEmissionsTracker @@ -167,7 +167,7 @@ def fill_table_from_combination(self, fixed_parameter_combinations: List[dict] = :type parameters: dict, optional :raises ParameterCombinationError: If any parameter of the combinations (rows) does not match the keyfields from the experiment configuration. - """ + """ rows = utils.combine_fill_table_parameters(self.config.database_configuration.keyfields.keys(), parameters, fixed_parameter_combinations) self.db_connector.create_table_if_not_existing() self.db_connector.fill_table(rows) @@ -187,7 +187,7 @@ def fill_table_from_config(self) -> None: error is raised. """ self.db_connector.create_table_if_not_existing() - parameters = self.config.experiments_configuration.experiment_configurations + parameters = self.config.experiments_configuration.experiment_configurations self.db_connector.fill_table(parameters) def fill_table_with_rows(self, rows: List[dict]) -> None: @@ -211,7 +211,9 @@ def fill_table_with_rows(self, rows: List[dict]) -> None: self.db_connector.create_table_if_not_existing() self.db_connector.fill_table(rows) - def execute(self, experiment_function: Callable[[Dict, Dict, ResultProcessor], None], max_experiments: int = -1) -> None: + def execute( + self, experiment_function: Callable[[Dict, Dict, ResultProcessor], None], max_experiments: int = -1, n_jobs: Optional[int] = None + ) -> None: """ Pulls open experiments from the database table and executes them. @@ -238,9 +240,12 @@ def execute(self, experiment_function: Callable[[Dict, Dict, ResultProcessor], N :type max_experiments: int, optional :raises InvalidValuesInConfiguration: If any value of the experiment parameters is of wrong data type. """ - with Parallel(n_jobs=self.config.n_jobs) as parallel: + if n_jobs is None: + n_jobs = self.config.n_jobs + + with Parallel(n_jobs=n_jobs) as parallel: if max_experiments == -1: - parallel(delayed(self._worker)(experiment_function) for _ in range(self.config.n_jobs)) + parallel(delayed(self._worker)(experiment_function) for _ in range(n_jobs)) else: parallel(delayed(self._execution_wrapper)(experiment_function) for _ in range(max_experiments)) self.logger.info("All configured executions finished.")