Skip to content

Commit

Permalink
adding custom orjson encoder to cache, custom encoding memoryview obj…
Browse files Browse the repository at this point in the history
…ects
  • Loading branch information
havok2063 committed Dec 18, 2024
1 parent 3b68b37 commit c719443
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
33 changes: 31 additions & 2 deletions python/valis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@

from __future__ import annotations

import base64
import hashlib
import json
import logging
import re
import orjson
from contextlib import asynccontextmanager
from functools import wraps
from inspect import Parameter, isawaitable, iscoroutinefunction
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Expand All @@ -37,7 +40,7 @@
get_typed_return_annotation,
get_typed_signature
)
from fastapi_cache import Backend, FastAPICache
from fastapi_cache import Backend, Coder, FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.backends.redis import RedisBackend
from fastapi_cache.decorator import _augment_signature, _locate_param
Expand Down Expand Up @@ -68,6 +71,32 @@
logger = logging.getLogger("uvicorn.error")


def bdefault(obj):
""" Custom encoder for orjson """
# handle python memoryview objects
if isinstance(obj, memoryview):
return base64.b64encode(obj.tobytes()).decode()
raise TypeError


class ORJsonCoder(Coder):
""" Custom encoder class for the cache that uses orjson """

@classmethod
def encode(cls, value: Any) -> bytes:
""" serialization """
return orjson.dumps(
value,
default=bdefault,
option=orjson.OPT_SERIALIZE_NUMPY,
)

@classmethod
def decode(cls, value: bytes) -> Any:
""" deserialization """
return orjson.loads(value)


@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
backend = settings.cache_backend
Expand Down Expand Up @@ -132,7 +161,7 @@ async def valis_cache_key_builder(

def valis_cache(
expire: Optional[int] = settings.cache_ttl,
coder: Optional[Type[Coder]] = None,
coder: Optional[Type[Coder]] = ORJsonCoder,
key_builder: Optional[KeyBuilder] = None,
namespace: str = "valis-cache",
injected_dependency_namespace: str = "__fastapi_cache",
Expand Down
5 changes: 3 additions & 2 deletions python/valis/routes/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ async def get_spectrum(self, sdss_id: Annotated[int, Path(title="The sdss_id of
product: Annotated[str, Query(description='The file species or data product name', example='specLite')],
ext: Annotated[str, Query(description='For multi-extension spectra, e.g. mwmStar, the name of the spectral extension', example='BOSS/APO')] = None,
):
return get_a_spectrum(sdss_id, product, self.release, ext=ext)
return list(get_a_spectrum(sdss_id, product, self.release, ext=ext))

@router.get('/catalogs/{sdss_id}', summary='Retrieve catalog information for a target sdss_id',
dependencies=[Depends(get_pw_db)],
Expand Down Expand Up @@ -216,6 +216,7 @@ async def get_catalogs(self, sdss_id: int = Path(title="The sdss_id of the targe
response_model=list[ParentCatalogModel],
responses={400: {'description': 'Invalid input sdss_id or catalog'}},
summary='Retrieve parent catalog information for a taget by sdss_id')
@valis_cache(namespace='valis-target')
async def get_parents(self,
catalog: Annotated[str, Path(description='The parent catalog to search',
example='gaia_dr3_source')],
Expand Down Expand Up @@ -246,7 +247,7 @@ async def get_parents(self,
@valis_cache(namespace='valis-target')
async def get_cartons(self, sdss_id: int = Path(title="The sdss_id of the target to get", example=23326)):
""" Return carton information for a given sdss_id """
return get_target_cartons(sdss_id).dicts().iterator()
return list(get_target_cartons(sdss_id).dicts())

@router.get('/pipelines/{sdss_id}', summary='Retrieve pipeline data for a target sdss_id',
dependencies=[Depends(get_pw_db)],
Expand Down

0 comments on commit c719443

Please sign in to comment.