From c719443647c9da1e6b01f73eafec3132fb49fd06 Mon Sep 17 00:00:00 2001 From: havok2063 Date: Wed, 18 Dec 2024 17:05:27 -0500 Subject: [PATCH] adding custom orjson encoder to cache, custom encoding memoryview objects --- python/valis/cache.py | 33 +++++++++++++++++++++++++++++++-- python/valis/routes/target.py | 5 +++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/python/valis/cache.py b/python/valis/cache.py index ba57087..e80c9d4 100644 --- a/python/valis/cache.py +++ b/python/valis/cache.py @@ -12,15 +12,18 @@ from __future__ import annotations +import base64 import hashlib import json import logging import re +import orjson from contextlib import asynccontextmanager from functools import wraps from inspect import Parameter, isawaitable, iscoroutinefunction from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, List, @@ -37,7 +40,7 @@ get_typed_return_annotation, get_typed_signature ) -from fastapi_cache import Backend, FastAPICache +from fastapi_cache import Backend, Coder, FastAPICache from fastapi_cache.backends.inmemory import InMemoryBackend from fastapi_cache.backends.redis import RedisBackend from fastapi_cache.decorator import _augment_signature, _locate_param @@ -68,6 +71,32 @@ logger = logging.getLogger("uvicorn.error") +def bdefault(obj): + """ Custom encoder for orjson """ + # handle python memoryview objects + if isinstance(obj, memoryview): + return base64.b64encode(obj.tobytes()).decode() + raise TypeError + + +class ORJsonCoder(Coder): + """ Custom encoder class for the cache that uses orjson """ + + @classmethod + def encode(cls, value: Any) -> bytes: + """ serialization """ + return orjson.dumps( + value, + default=bdefault, + option=orjson.OPT_SERIALIZE_NUMPY, + ) + + @classmethod + def decode(cls, value: bytes) -> Any: + """ deserialization """ + return orjson.loads(value) + + @asynccontextmanager async def lifespan(_: FastAPI) -> AsyncIterator[None]: backend = settings.cache_backend @@ -132,7 +161,7 @@ async def valis_cache_key_builder( def valis_cache( expire: Optional[int] = settings.cache_ttl, - coder: Optional[Type[Coder]] = None, + coder: Optional[Type[Coder]] = ORJsonCoder, key_builder: Optional[KeyBuilder] = None, namespace: str = "valis-cache", injected_dependency_namespace: str = "__fastapi_cache", diff --git a/python/valis/routes/target.py b/python/valis/routes/target.py index 29a8bed..c5f39d8 100644 --- a/python/valis/routes/target.py +++ b/python/valis/routes/target.py @@ -188,7 +188,7 @@ async def get_spectrum(self, sdss_id: Annotated[int, Path(title="The sdss_id of product: Annotated[str, Query(description='The file species or data product name', example='specLite')], ext: Annotated[str, Query(description='For multi-extension spectra, e.g. mwmStar, the name of the spectral extension', example='BOSS/APO')] = None, ): - return get_a_spectrum(sdss_id, product, self.release, ext=ext) + return list(get_a_spectrum(sdss_id, product, self.release, ext=ext)) @router.get('/catalogs/{sdss_id}', summary='Retrieve catalog information for a target sdss_id', dependencies=[Depends(get_pw_db)], @@ -216,6 +216,7 @@ async def get_catalogs(self, sdss_id: int = Path(title="The sdss_id of the targe response_model=list[ParentCatalogModel], responses={400: {'description': 'Invalid input sdss_id or catalog'}}, summary='Retrieve parent catalog information for a taget by sdss_id') + @valis_cache(namespace='valis-target') async def get_parents(self, catalog: Annotated[str, Path(description='The parent catalog to search', example='gaia_dr3_source')], @@ -246,7 +247,7 @@ async def get_parents(self, @valis_cache(namespace='valis-target') async def get_cartons(self, sdss_id: int = Path(title="The sdss_id of the target to get", example=23326)): """ Return carton information for a given sdss_id """ - return get_target_cartons(sdss_id).dicts().iterator() + return list(get_target_cartons(sdss_id).dicts()) @router.get('/pipelines/{sdss_id}', summary='Retrieve pipeline data for a target sdss_id', dependencies=[Depends(get_pw_db)],