Skip to content

Commit

Permalink
Fix slipped logging context when media rejected
Browse files Browse the repository at this point in the history
When a module rejects a piece of media we end up trying to close the
same logging context twice.

Instead of fixing the existing code we refactor to use an async context
manager, which is easier to write correctly.
  • Loading branch information
erikjohnston committed May 28, 2024
1 parent 9edb725 commit 4433521
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 91 deletions.
11 changes: 2 additions & 9 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ async def _download_remote_file(

file_info = FileInfo(server_name=server_name, file_id=file_id)

with self.media_storage.store_into_file(file_info) as (f, fname, finish):
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers = await self.client.download_media(
server_name,
Expand Down Expand Up @@ -693,8 +693,6 @@ async def _download_remote_file(
)
raise SynapseError(502, "Failed to fetch remote media")

await finish()

if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else:
Expand Down Expand Up @@ -1045,14 +1043,9 @@ async def _generate_thumbnails(
),
)

with self.media_storage.store_into_file(file_info) as (
f,
fname,
finish,
):
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
await self.media_storage.write_to_file(t_byte_source, f)
await finish()
finally:
t_byte_source.close()

Expand Down
102 changes: 37 additions & 65 deletions synapse/media/media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
IO,
TYPE_CHECKING,
Any,
Awaitable,
AsyncIterator,
BinaryIO,
Callable,
Generator,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -97,11 +96,9 @@ async def store_file(self, source: IO, file_info: FileInfo) -> str:
the file path written to in the primary media store
"""

with self.store_into_file(file_info) as (f, fname, finish_cb):
async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository
await self.write_to_file(source, f)
# Write to the other storage providers
await finish_cb()

return fname

Expand All @@ -111,32 +108,27 @@ async def write_to_file(self, source: IO, output: IO) -> None:
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)

@trace_with_opname("MediaStorage.store_into_file")
@contextlib.contextmanager
def store_into_file(
@contextlib.asynccontextmanager
async def store_into_file(
self, file_info: FileInfo
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as
) -> AsyncIterator[Tuple[BinaryIO, str]]:
"""Async context manager used to get a file-like object to write into, as
described by file_info.
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
on disk, and finish_cb is a function that returns an awaitable.
Actually yields a 2-tuple (file, fname), where file is a file-like
object that can be written to and fname is the absolute path of the file
on disk.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
finish_cb must be called and awaited after the file has been successfully
written to. It should not be called if there was an error. It checks for spam
and stores the file into the configured storage providers.
Args:
file_info: Info about the file to store
Example:
with media_storage.store_into_file(info) as (f, fname, finish_cb):
async with media_storage.store_into_file(info) as (f, fname,):
# .. write into f ...
await finish_cb()
"""

path = self._file_info_to_path(file_info)
Expand All @@ -145,62 +137,42 @@ def store_into_file(
dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)

finished_called = [False]

main_media_repo_write_trace_scope = start_active_span(
"writing to main media repo"
)
main_media_repo_write_trace_scope.__enter__()

try:
with open(fname, "wb") as f:

async def finish() -> None:
# When someone calls finish, we assume they are done writing to the main media repo
main_media_repo_write_trace_scope.__exit__(None, None, None)

with start_active_span("writing to other storage providers"):
# Ensure that all writes have been flushed and close the
# file.
f.flush()
f.close()

spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
# We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])

for provider in self.storage_providers:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)

finished_called[0] = True

yield f, fname, finish
except Exception as e:
with main_media_repo_write_trace_scope:
try:
main_media_repo_write_trace_scope.__exit__(
type(e), None, e.__traceback__
)
os.remove(fname)
except Exception:
pass
with open(fname, "wb") as f:
yield f, fname

raise e from None
except Exception as e:
try:
os.remove(fname)
except Exception:
pass

if not finished_called:
exc = Exception("Finished callback not called")
main_media_repo_write_trace_scope.__exit__(
type(exc), None, exc.__traceback__
raise e from None

with start_active_span("writing to other storage providers"):
spam_check = (
await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
)
raise exc
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
# We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])

for provider in self.storage_providers:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)

async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
Expand Down
4 changes: 1 addition & 3 deletions synapse/media/url_previewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ async def _handle_url(

file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)

with self.media_storage.store_into_file(file_info) as (f, fname, finish):
async with self.media_storage.store_into_file(file_info) as (f, fname):
if url.startswith("data:"):
if not allow_data_urls:
raise SynapseError(
Expand All @@ -603,8 +603,6 @@ async def _handle_url(
else:
download_result = await self._download_url(url, f)

await finish()

try:
time_now_ms = self.clock.time_msec()

Expand Down
14 changes: 7 additions & 7 deletions tests/rest/client/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
f,
fname,
finish,
):
f.write(SMALL_PNG)
self.get_success(finish())

media_storage = hs.get_media_repository().media_storage

ctx = media_storage.store_into_file(file_info)
(f, fname) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG)
self.get_success(ctx.__aexit__(None, None, None))

self.get_success(
self.store.store_cached_remote_media(
Expand Down
14 changes: 7 additions & 7 deletions tests/rest/media/test_domain_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
f,
fname,
finish,
):
f.write(SMALL_PNG)
self.get_success(finish())

media_storage = hs.get_media_repository().media_storage

ctx = media_storage.store_into_file(file_info)
(f, fname) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG)
self.get_success(ctx.__aexit__(None, None, None))

self.get_success(
self.store.store_cached_remote_media(
Expand Down

0 comments on commit 4433521

Please sign in to comment.