From 4433521a25bc4d17cf18c262213f9eb8d6197943 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 28 May 2024 14:47:13 +0100 Subject: [PATCH] Fix slipped logging context when media rejected When a module rejects a piece of media we end up trying to close the same logging context twice. Instead of fixing the existing code we refactor to use an async context manager, which is easier to write correctly. --- synapse/media/media_repository.py | 11 +-- synapse/media/media_storage.py | 102 ++++++++--------------- synapse/media/url_previewer.py | 4 +- tests/rest/client/test_media.py | 14 ++-- tests/rest/media/test_domain_blocking.py | 14 ++-- 5 files changed, 54 insertions(+), 91 deletions(-) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 0e875132f6f..9da84959501 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -650,7 +650,7 @@ async def _download_remote_file( file_info = FileInfo(server_name=server_name, file_id=file_id) - with self.media_storage.store_into_file(file_info) as (f, fname, finish): + async with self.media_storage.store_into_file(file_info) as (f, fname): try: length, headers = await self.client.download_media( server_name, @@ -693,8 +693,6 @@ async def _download_remote_file( ) raise SynapseError(502, "Failed to fetch remote media") - await finish() - if b"Content-Type" in headers: media_type = headers[b"Content-Type"][0].decode("ascii") else: @@ -1045,14 +1043,9 @@ async def _generate_thumbnails( ), ) - with self.media_storage.store_into_file(file_info) as ( - f, - fname, - finish, - ): + async with self.media_storage.store_into_file(file_info) as (f, fname): try: await self.media_storage.write_to_file(t_byte_source, f) - await finish() finally: t_byte_source.close() diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index b45b319f5c9..9979c48eaca 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -27,10 +27,9 @@ IO, TYPE_CHECKING, Any, - Awaitable, + AsyncIterator, BinaryIO, Callable, - Generator, Optional, Sequence, Tuple, @@ -97,11 +96,9 @@ async def store_file(self, source: IO, file_info: FileInfo) -> str: the file path written to in the primary media store """ - with self.store_into_file(file_info) as (f, fname, finish_cb): + async with self.store_into_file(file_info) as (f, fname): # Write to the main media repository await self.write_to_file(source, f) - # Write to the other storage providers - await finish_cb() return fname @@ -111,32 +108,27 @@ async def write_to_file(self, source: IO, output: IO) -> None: await defer_to_thread(self.reactor, _write_file_synchronously, source, output) @trace_with_opname("MediaStorage.store_into_file") - @contextlib.contextmanager - def store_into_file( + @contextlib.asynccontextmanager + async def store_into_file( self, file_info: FileInfo - ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: - """Context manager used to get a file like object to write into, as + ) -> AsyncIterator[Tuple[BinaryIO, str]]: + """Async Context manager used to get a file like object to write into, as described by file_info. - Actually yields a 3-tuple (file, fname, finish_cb), where file is a file - like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns an awaitable. + Actually yields a 2-tuple (file, fname,), where file is a file + like object that can be written to and fname is the absolute path of file + on disk. fname can be used to read the contents from after upload, e.g. to generate thumbnails. - finish_cb must be called and waited on after the file has been successfully been - written to. Should not be called if there was an error. Checks for spam and - stores the file into the configured storage providers. - Args: file_info: Info about the file to store Example: - with media_storage.store_into_file(info) as (f, fname, finish_cb): + async with media_storage.store_into_file(info) as (f, fname,): # .. write into f ... - await finish_cb() """ path = self._file_info_to_path(file_info) @@ -145,62 +137,42 @@ def store_into_file( dirname = os.path.dirname(fname) os.makedirs(dirname, exist_ok=True) - finished_called = [False] - main_media_repo_write_trace_scope = start_active_span( "writing to main media repo" ) main_media_repo_write_trace_scope.__enter__() - try: - with open(fname, "wb") as f: - - async def finish() -> None: - # When someone calls finish, we assume they are done writing to the main media repo - main_media_repo_write_trace_scope.__exit__(None, None, None) - - with start_active_span("writing to other storage providers"): - # Ensure that all writes have been flushed and close the - # file. - f.flush() - f.close() - - spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam( - ReadableFileWrapper(self.clock, fname), file_info - ) - if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: - logger.info("Blocking media due to spam checker") - # Note that we'll delete the stored media, due to the - # try/except below. The media also won't be stored in - # the DB. - # We currently ignore any additional field returned by - # the spam-check API. - raise SpamMediaException(errcode=spam_check[0]) - - for provider in self.storage_providers: - with start_active_span(str(provider)): - await provider.store_file(path, file_info) - - finished_called[0] = True - - yield f, fname, finish - except Exception as e: + with main_media_repo_write_trace_scope: try: - main_media_repo_write_trace_scope.__exit__( - type(e), None, e.__traceback__ - ) - os.remove(fname) - except Exception: - pass + with open(fname, "wb") as f: + yield f, fname - raise e from None + except Exception as e: + try: + os.remove(fname) + except Exception: + pass - if not finished_called: - exc = Exception("Finished callback not called") - main_media_repo_write_trace_scope.__exit__( - type(exc), None, exc.__traceback__ + raise e from None + + with start_active_span("writing to other storage providers"): + spam_check = ( + await self._spam_checker_module_callbacks.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) ) - raise exc + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + # We currently ignore any additional field returned by + # the spam-check API. + raise SpamMediaException(errcode=spam_check[0]) + + for provider in self.storage_providers: + with start_active_span(str(provider)): + await provider.store_file(path, file_info) async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index 3897823b35e..2e65a047897 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -592,7 +592,7 @@ async def _handle_url( file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) - with self.media_storage.store_into_file(file_info) as (f, fname, finish): + async with self.media_storage.store_into_file(file_info) as (f, fname): if url.startswith("data:"): if not allow_data_urls: raise SynapseError( @@ -603,8 +603,6 @@ async def _handle_url( else: download_result = await self._download_url(url, f) - await finish() - try: time_now_ms = self.clock.time_msec() diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 600cbf89630..be4a289ec1b 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -93,13 +93,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # from a regular 404. file_id = "abcdefg12345" file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) - with hs.get_media_repository().media_storage.store_into_file(file_info) as ( - f, - fname, - finish, - ): - f.write(SMALL_PNG) - self.get_success(finish()) + + media_storage = hs.get_media_repository().media_storage + + ctx = media_storage.store_into_file(file_info) + (f, fname) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) self.get_success( self.store.store_cached_remote_media( diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py index 88988f3a223..72205c6bb3b 100644 --- a/tests/rest/media/test_domain_blocking.py +++ b/tests/rest/media/test_domain_blocking.py @@ -44,13 +44,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # from a regular 404. file_id = "abcdefg12345" file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) - with hs.get_media_repository().media_storage.store_into_file(file_info) as ( - f, - fname, - finish, - ): - f.write(SMALL_PNG) - self.get_success(finish()) + + media_storage = hs.get_media_repository().media_storage + + ctx = media_storage.store_into_file(file_info) + (f, fname) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) self.get_success( self.store.store_cached_remote_media(