Skip to content

Commit

Permalink
Add n_jobs parameter to execute
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasFehring committed Jan 8, 2024
1 parent 6b5e575 commit 831f96c
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions py_experimenter/experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import socket
import traceback
from typing import Callable, Dict, List, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
from codecarbon import EmissionsTracker, OfflineEmissionsTracker
Expand Down Expand Up @@ -167,7 +167,7 @@ def fill_table_from_combination(self, fixed_parameter_combinations: List[dict] =
:type parameters: dict, optional
:raises ParameterCombinationError: If any parameter of the combinations (rows) does not match the keyfields
from the experiment configuration.
"""
"""
rows = utils.combine_fill_table_parameters(self.config.database_configuration.keyfields.keys(), parameters, fixed_parameter_combinations)
self.db_connector.create_table_if_not_existing()
self.db_connector.fill_table(rows)
Expand All @@ -187,7 +187,7 @@ def fill_table_from_config(self) -> None:
error is raised.
"""
self.db_connector.create_table_if_not_existing()
parameters = self.config.experiments_configuration.experiment_configurations
parameters = self.config.experiments_configuration.experiment_configurations
self.db_connector.fill_table(parameters)

def fill_table_with_rows(self, rows: List[dict]) -> None:
Expand All @@ -211,7 +211,9 @@ def fill_table_with_rows(self, rows: List[dict]) -> None:
self.db_connector.create_table_if_not_existing()
self.db_connector.fill_table(rows)

def execute(self, experiment_function: Callable[[Dict, Dict, ResultProcessor], None], max_experiments: int = -1) -> None:
def execute(
self, experiment_function: Callable[[Dict, Dict, ResultProcessor], None], max_experiments: int = -1, n_jobs: Optional[int] = None
) -> None:
"""
Pulls open experiments from the database table and executes them.
Expand All @@ -238,9 +240,12 @@ def execute(self, experiment_function: Callable[[Dict, Dict, ResultProcessor], N
:type max_experiments: int, optional
:raises InvalidValuesInConfiguration: If any value of the experiment parameters is of wrong data type.
"""
with Parallel(n_jobs=self.config.n_jobs) as parallel:
if n_jobs is None:
n_jobs = self.config.n_jobs

with Parallel(n_jobs=n_jobs) as parallel:
if max_experiments == -1:
parallel(delayed(self._worker)(experiment_function) for _ in range(self.config.n_jobs))
parallel(delayed(self._worker)(experiment_function) for _ in range(n_jobs))
else:
parallel(delayed(self._execution_wrapper)(experiment_function) for _ in range(max_experiments))
self.logger.info("All configured executions finished.")
Expand Down

0 comments on commit 831f96c

Please sign in to comment.