def exponential_backoff_request(self, method, *args, max_retries=5, initial_wait_time=2, max_wait_time=300,
                                **kwargs):
    """
    Execute a request, retrying with exponential backoff on transient server errors.

    ``method`` is called as ``method(*args, **kwargs)``. If it raises an exception whose
    message looks like a transient server-side failure, the call is retried after a
    pause that doubles on each attempt (capped at ``max_wait_time``). Any other
    exception stops the retry loop immediately.

    Exceptions raised by ``method`` are never propagated: they are logged and ``None``
    is returned instead. A ``try``/``except`` wrapped around this call will therefore
    not observe failures of ``method``; callers that need to distinguish failure from
    a ``method`` that legitimately returns ``None`` must do so by other means.

    :param method: The callable to execute.
    :param args: Positional arguments forwarded to ``method``.
    :param max_retries: Maximum number of attempts (default 5).
    :param initial_wait_time: Wait time in seconds before the first retry (default 2).
    :param max_wait_time: Upper bound on the wait time in seconds (default 300).
    :param kwargs: Keyword arguments forwarded to ``method``.
    :return: The return value of ``method``, or None if every attempt failed or a
        non-retriable error was raised.
    """
    # Substrings of the error message that mark a failure as transient.
    # NOTE(review): the bare "500" check also matches any message that merely contains
    # those digits (e.g. a count or an identifier); consider inspecting a status code.
    retriable_markers = (
        "server is overloaded",
        "500",
        "The server had an error while processing your request",
    )
    wait_time = initial_wait_time

    for attempt in range(max_retries):
        logger.info(f"Attempt {attempt + 1}/{max_retries} for batch operation.")
        try:
            return method(*args, **kwargs)
        except Exception as e:
            # The retry decision is made on the message text, so a single broad
            # handler is sufficient; listing subclasses alongside Exception adds nothing.
            error_message = str(e)

            if not any(marker in error_message for marker in retriable_markers):
                logger.error(f"Non-retriable error encountered: {error_message}, Error Type: {type(e).__name__}.")
                break

            if attempt == max_retries - 1:
                # Out of attempts: report and stop without sleeping, since no retry follows.
                logger.error(
                    f"Failed to process batch after {max_retries} retries. Last error: {error_message}")
                break

            logger.warning(
                f"Attempt {attempt + 1}: Error encountered - {error_message}. Retrying in {wait_time} seconds.")
            time.sleep(wait_time)
            wait_time = min(wait_time * 2, max_wait_time)

    return None