diff --git a/aiocache/decorators.py b/aiocache/decorators.py index d5cdac5d..e6fa3639 100644 --- a/aiocache/decorators.py +++ b/aiocache/decorators.py @@ -15,6 +15,10 @@ class cached: Caches the functions return value into a key generated with module_name, function_name and args. The cache is available in the function object as ``.cache``. + To invalidate the cache, you can use the ``invalidate_cache`` method of the function object by + passing the args that were used to generate the cache key as + ``await .invalidate_cache(*args, **kwargs)``. It is an async method. + In some cases you will need to send more args to configure the cache object. An example would be endpoint and port for the Redis cache. You can send those args as kwargs and they will be propagated accordingly. @@ -77,6 +81,7 @@ def __init__( self.alias = alias self.cache = None + self._func = None self._cache = cache self._serializer = serializer self._namespace = namespace @@ -84,10 +89,12 @@ def __init__( self._kwargs = kwargs def __call__(self, f): + self._func = f + if self.alias: self.cache = caches.get(self.alias) for arg in ("serializer", "namespace", "plugins"): - if getattr(self, f'_{arg}', None) is not None: + if getattr(self, f"_{arg}", None) is not None: logger.warning(f"Using cache alias; ignoring {arg!r} argument.") else: self.cache = _get_cache( @@ -103,6 +110,7 @@ async def wrapper(*args, **kwargs): return await self.decorator(f, *args, **kwargs) wrapper.cache = self.cache + wrapper.invalidate_cache = self.invalidate_cache return wrapper async def decorator( @@ -157,6 +165,10 @@ async def set_in_cache(self, key, value): except Exception: logger.exception("Couldn't set %s in key %s, unexpected error", value, key) + async def invalidate_cache(self, *args, **kwargs): + key = self.get_cache_key(self._func, args, kwargs) + return await self.cache.delete(key) + class cached_stampede(cached): """ @@ -330,7 +342,7 @@ def __call__(self, f): if self.alias: self.cache = caches.get(self.alias) for arg in ("serializer", "namespace", "plugins"): - if getattr(self, f'_{arg}', None) is not None: + if getattr(self, f"_{arg}", None) is not None: logger.warning(f"Using cache alias; ignoring {arg!r} argument.") else: self.cache = _get_cache( diff --git a/tests/ut/test_decorators.py b/tests/ut/test_decorators.py index cfa81e1b..8ba8bd20 100644 --- a/tests/ut/test_decorators.py +++ b/tests/ut/test_decorators.py @@ -233,6 +233,64 @@ async def bar(): assert foo.cache != bar.cache + async def test_invalidate_cache_exists(self): + @cached() + async def foo(): + """Dummy function.""" + + assert callable(foo.invalidate_cache) + + async def test_invalidate_cache(self): + cache_misses = 0 + + @cached(ttl=60 * 60) + async def foo(return_value: str): + nonlocal cache_misses + cache_misses += 1 + return return_value + + await foo("hello") # increments cache_misses since it's not cached + assert cache_misses == 1 + + await foo("hello") # doesn't increment cache_misses since it's cached + assert cache_misses == 1 + + await foo.invalidate_cache("hello") + await foo("hello") # increments cache_misses since the cache was invalidated + assert cache_misses == 2 + + await foo("hello") # doesn't increment cache_misses since it's cached + assert cache_misses == 2 + + async def test_invalidate_cache_diff_args(self): + """ + Tests that the invalidate_cache invalidates the cache for the correct arguments. + """ + + cache_misses = 0 + + @cached(ttl=60 * 60) + async def foo(return_value: str): + nonlocal cache_misses + cache_misses += 1 + return return_value + + await foo("hello") # increments cache_misses since "hello" is not cached + assert cache_misses == 1 + + await foo("world") # increments cache_misses since "world" is not cached + assert cache_misses == 2 + + await foo.invalidate_cache("world") + await foo("hello") # doesn't increment cache_misses since "hello" is still cached + await foo("hello") + await foo("hello") + await foo("hello") + assert cache_misses == 2 + + await foo("world") + assert cache_misses == 3 + class TestCachedStampede: @pytest.fixture @@ -476,8 +534,9 @@ async def test_cache_write_doesnt_wait_for_future(self, mocker, decorator, decor mocker.spy(decorator, "set_in_cache") with patch.object(decorator, "get_from_cache", autospec=True, return_value=[None, None]): with patch("aiocache.decorators.asyncio.ensure_future", autospec=True): - await decorator_call(1, keys=["a", "b"], value="value", - aiocache_wait_for_write=False) + await decorator_call( + 1, keys=["a", "b"], value="value", aiocache_wait_for_write=False + ) decorator.set_in_cache.assert_not_awaited() decorator.set_in_cache.assert_called_once_with({"a": ANY, "b": ANY}, stub_dict, ANY, ANY)