Skip to content

Commit

Permalink
refresh cohorts based on flag configs in storage
Browse files Browse the repository at this point in the history
  • Loading branch information
tyiuhc committed Aug 6, 2024
1 parent 7043732 commit 1d974f1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 26 deletions.
28 changes: 13 additions & 15 deletions src/amplitude_experiment/cohort/cohort_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .cohort import Cohort
from .cohort_download_api import CohortDownloadApi
from .cohort_storage import CohortStorage
from ..exception import CohortUpdateException
from ..exception import CohortsDownloadException


class CohortLoader:
Expand All @@ -30,39 +30,37 @@ def load_cohort(self, cohort_id: str) -> Future:

def _remove_job(self, cohort_id: str):
if cohort_id in self.jobs:
del self.jobs[cohort_id]
with self.lock_jobs:
self.jobs.pop(cohort_id, None)

def download_cohort(self, cohort_id: str) -> Cohort:
cohort = self.cohort_storage.get_cohort(cohort_id)
return self.cohort_download_api.get_cohort(cohort_id, cohort)

def update_stored_cohorts(self) -> Future:
def update_task():
def download_cohorts(self, cohort_ids: Set[str]) -> Future:
def update_task(task_cohort_ids):
errors = []
cohort_ids = self.cohort_storage.get_cohort_ids()

futures = []
with self.lock_jobs:
for cohort_id in cohort_ids:
future = self.load_cohort(cohort_id)
futures.append(future)
for cohort_id in task_cohort_ids:
future = self.load_cohort(cohort_id)
futures.append(future)

for future in as_completed(futures):
cohort_id = next(c_id for c_id, f in self.jobs.items() if f == future)
try:
future.result()
except Exception as e:
errors.append((cohort_id, e))
cohort_id = next((c_id for c_id, f in self.jobs.items() if f == future), None)
if cohort_id:
errors.append((cohort_id, e))

if errors:
raise CohortUpdateException(errors)
raise CohortsDownloadException(errors)

return self.executor.submit(update_task)
return self.executor.submit(update_task, cohort_ids)

def __load_cohort_internal(self, cohort_id):
try:
cohort = self.download_cohort(cohort_id)
# None is returned when cohort is not modified
if cohort is not None:
self.cohort_storage.put_cohort(cohort)
except Exception as e:
Expand Down
20 changes: 10 additions & 10 deletions src/amplitude_experiment/deployment/deployment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from ..flag.flag_config_api import FlagConfigApi
from ..flag.flag_config_storage import FlagConfigStorage
from ..local.poller import Poller
from ..util.flag_config import get_all_cohort_ids_from_flag
from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags

COHORT_POLLING_INTERVAL_MILLIS = 60000


class DeploymentRunner:
Expand All @@ -29,7 +31,7 @@ def __init__(
self.lock = threading.Lock()
self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update)
if self.cohort_loader:
self.cohort_poller = Poller(self.config.flag_config_polling_interval_millis / 1000,
self.cohort_poller = Poller(COHORT_POLLING_INTERVAL_MILLIS / 1000,
self.__update_cohorts)
self.logger = logger

Expand Down Expand Up @@ -71,15 +73,12 @@ def __update_flag_configs(self):

existing_cohort_ids = self.cohort_storage.get_cohort_ids()
cohort_ids_to_download = new_cohort_ids - existing_cohort_ids
cohort_download_errors = []

# download all new cohorts
for cohort_id in cohort_ids_to_download:
try:
self.cohort_loader.load_cohort(cohort_id).result()
except Exception as e:
cohort_download_errors.append((cohort_id, str(e)))
self.logger.warning(f"Download cohort {cohort_id} failed: {e}")
try:
self.cohort_loader.download_cohorts(cohort_ids_to_download).result()
except Exception as e:
self.logger.warning(f"Error while downloading cohorts: {e}")

# get updated set of cohort ids
updated_cohort_ids = self.cohort_storage.get_cohort_ids()
Expand All @@ -97,8 +96,9 @@ def __update_flag_configs(self):
self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.")

def __update_cohorts(self):
cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values()))
try:
self.cohort_loader.update_stored_cohorts().result()
self.cohort_loader.download_cohorts(cohort_ids).result()
except Exception as e:
self.logger.warning(f"Error while updating cohorts: {e}")

Expand Down
2 changes: 1 addition & 1 deletion src/amplitude_experiment/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, status_code, message):
self.status_code = status_code


class CohortUpdateException(Exception):
class CohortsDownloadException(Exception):
def __init__(self, errors):
self.errors = errors
super().__init__(self.__str__())
Expand Down

0 comments on commit 1d974f1

Please sign in to comment.