summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-11-09 11:00:30 -0500
committerGitHub <noreply@github.com>2023-11-09 11:00:30 -0500
commitff716b483b07b21de72d999250fdf9397003a914 (patch)
tree466db8d8885737346c41457cee9bf8fc549d387a
parentBulk-invalidate e2e cached queries after claiming keys (#16613) (diff)
downloadsynapse-ff716b483b07b21de72d999250fdf9397003a914.tar.xz
Return attrs for more media repo APIs. (#16611)
-rw-r--r--changelog.d/16611.misc1
-rw-r--r--synapse/handlers/profile.py15
-rw-r--r--synapse/handlers/sso.py2
-rw-r--r--synapse/media/media_repository.py70
-rw-r--r--synapse/media/url_previewer.py11
-rw-r--r--synapse/rest/media/thumbnail_resource.py16
-rw-r--r--synapse/storage/databases/main/media_repository.py105
-rw-r--r--tests/media/test_media_storage.py2
-rw-r--r--tests/rest/admin/test_media.py16
-rw-r--r--tests/rest/media/test_media_retention.py20
10 files changed, 148 insertions, 110 deletions
diff --git a/changelog.d/16611.misc b/changelog.d/16611.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16611.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c2109036ec..1027fbfd28 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 from synapse.api.errors import (
     AuthError,
@@ -23,6 +23,7 @@ from synapse.api.errors import (
     StoreError,
     SynapseError,
 )
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
 from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -306,7 +307,9 @@ class ProfileHandler:
             server_name = host
 
         if self._is_mine_server_name(server_name):
-            media_info = await self.store.get_local_media(media_id)
+            media_info: Optional[
+                Union[LocalMedia, RemoteMedia]
+            ] = await self.store.get_local_media(media_id)
         else:
             media_info = await self.store.get_cached_remote_media(server_name, media_id)
 
@@ -322,12 +325,12 @@ class ProfileHandler:
 
         if self.max_avatar_size:
             # Ensure avatar does not exceed max allowed avatar size
-            if media_info["media_length"] > self.max_avatar_size:
+            if media_info.media_length > self.max_avatar_size:
                 logger.warning(
                     "Forbidding avatar change to %s: %d bytes is above the allowed size "
                     "limit",
                     mxc,
-                    media_info["media_length"],
+                    media_info.media_length,
                 )
                 return False
 
@@ -335,12 +338,12 @@ class ProfileHandler:
             # Ensure the avatar's file type is allowed
             if (
                 self.allowed_avatar_mimetypes
-                and media_info["media_type"] not in self.allowed_avatar_mimetypes
+                and media_info.media_type not in self.allowed_avatar_mimetypes
             ):
                 logger.warning(
                     "Forbidding avatar change to %s: mimetype %s not allowed",
                     mxc,
-                    media_info["media_type"],
+                    media_info.media_type,
                 )
                 return False
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 62f2454f5d..389dc5298a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -806,7 +806,7 @@ class SsoHandler:
                 media_id = profile["avatar_url"].split("/")[-1]
                 if self._is_mine_server_name(server_name):
                     media = await self._media_repo.store.get_local_media(media_id)
-                    if media is not None and upload_name == media["upload_name"]:
+                    if media is not None and upload_name == media.upload_name:
                         logger.info("skipping saving the user avatar")
                         return True
 
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 72b0f1c5de..1957426c6a 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -19,6 +19,7 @@ import shutil
 from io import BytesIO
 from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
+import attr
 from matrix_common.types.mxc_uri import MXCUri
 
 import twisted.internet.error
@@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
 from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
 from synapse.media.url_previewer import UrlPreviewer
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.media_repository import RemoteMedia
 from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -245,18 +247,18 @@ class MediaRepository:
             Resolves once a response has successfully been written to request
         """
         media_info = await self.store.get_local_media(media_id)
-        if not media_info or media_info["quarantined_by"]:
+        if not media_info or media_info.quarantined_by:
             respond_404(request)
             return
 
         self.mark_recently_accessed(None, media_id)
 
-        media_type = media_info["media_type"]
+        media_type = media_info.media_type
         if not media_type:
             media_type = "application/octet-stream"
-        media_length = media_info["media_length"]
-        upload_name = name if name else media_info["upload_name"]
-        url_cache = media_info["url_cache"]
+        media_length = media_info.media_length
+        upload_name = name if name else media_info.upload_name
+        url_cache = media_info.url_cache
 
         file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
 
@@ -310,16 +312,20 @@ class MediaRepository:
 
         # We deliberately stream the file outside the lock
         if responder:
-            media_type = media_info["media_type"]
-            media_length = media_info["media_length"]
-            upload_name = name if name else media_info["upload_name"]
+            upload_name = name if name else media_info.upload_name
             await respond_with_responder(
-                request, responder, media_type, media_length, upload_name
+                request,
+                responder,
+                media_info.media_type,
+                media_info.media_length,
+                upload_name,
             )
         else:
             respond_404(request)
 
-    async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+    async def get_remote_media_info(
+        self, server_name: str, media_id: str
+    ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
 
@@ -353,7 +359,7 @@ class MediaRepository:
 
     async def _get_remote_media_impl(
         self, server_name: str, media_id: str
-    ) -> Tuple[Optional[Responder], dict]:
+    ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
 
@@ -373,15 +379,17 @@ class MediaRepository:
 
         # If we have an entry in the DB, try and look for it
         if media_info:
-            file_id = media_info["filesystem_id"]
+            file_id = media_info.filesystem_id
             file_info = FileInfo(server_name, file_id)
 
-            if media_info["quarantined_by"]:
+            if media_info.quarantined_by:
                 logger.info("Media is quarantined")
                 raise NotFoundError()
 
-            if not media_info["media_type"]:
-                media_info["media_type"] = "application/octet-stream"
+            if not media_info.media_type:
+                media_info = attr.evolve(
+                    media_info, media_type="application/octet-stream"
+                )
 
             responder = await self.media_storage.fetch_media(file_info)
             if responder:
@@ -403,9 +411,9 @@ class MediaRepository:
             if not media_info:
                 raise e
 
-        file_id = media_info["filesystem_id"]
-        if not media_info["media_type"]:
-            media_info["media_type"] = "application/octet-stream"
+        file_id = media_info.filesystem_id
+        if not media_info.media_type:
+            media_info = attr.evolve(media_info, media_type="application/octet-stream")
         file_info = FileInfo(server_name, file_id)
 
         # We generate thumbnails even if another process downloaded the media
@@ -415,7 +423,7 @@ class MediaRepository:
         # otherwise they'll request thumbnails and get a 404 if they're not
         # ready yet.
         await self._generate_thumbnails(
-            server_name, media_id, file_id, media_info["media_type"]
+            server_name, media_id, file_id, media_info.media_type
         )
 
         responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +433,7 @@ class MediaRepository:
         self,
         server_name: str,
         media_id: str,
-    ) -> dict:
+    ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
@@ -518,7 +526,7 @@ class MediaRepository:
                 origin=server_name,
                 media_id=media_id,
                 media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
+                time_now_ms=time_now_ms,
                 upload_name=upload_name,
                 media_length=length,
                 filesystem_id=file_id,
@@ -526,15 +534,17 @@ class MediaRepository:
 
         logger.info("Stored remote media in file %r", fname)
 
-        media_info = {
-            "media_type": media_type,
-            "media_length": length,
-            "upload_name": upload_name,
-            "created_ts": time_now_ms,
-            "filesystem_id": file_id,
-        }
-
-        return media_info
+        return RemoteMedia(
+            media_origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            media_length=length,
+            upload_name=upload_name,
+            created_ts=time_now_ms,
+            filesystem_id=file_id,
+            last_access_ts=time_now_ms,
+            quarantined_by=None,
+        )
 
     def _get_thumbnail_requirements(
         self, media_type: str
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 9b5a3dd5f4..44aac21de6 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -240,15 +240,14 @@ class UrlPreviewer:
         cache_result = await self.store.get_url_cache(url, ts)
         if (
             cache_result
-            and cache_result["expires_ts"] > ts
-            and cache_result["response_code"] / 100 == 2
+            and cache_result.expires_ts > ts
+            and cache_result.response_code // 100 == 2
         ):
             # It may be stored as text in the database, not as bytes (such as
             # PostgreSQL). If so, encode it back before handing it on.
-            og = cache_result["og"]
-            if isinstance(og, str):
-                og = og.encode("utf8")
-            return og
+            if isinstance(cache_result.og, str):
+                return cache_result.og.encode("utf8")
+            return cache_result.og
 
         # If this URL can be accessed via an allowed oEmbed, use that instead.
         url_to_download = url
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 85b6bdbe72..efda8b4ab4 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet):
         if not media_info:
             respond_404(request)
             return
-        if media_info["quarantined_by"]:
+        if media_info.quarantined_by:
             logger.info("Media is quarantined")
             respond_404(request)
             return
@@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet):
             thumbnail_infos,
             media_id,
             media_id,
-            url_cache=bool(media_info["url_cache"]),
+            url_cache=bool(media_info.url_cache),
             server_name=None,
         )
 
@@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet):
         if not media_info:
             respond_404(request)
             return
-        if media_info["quarantined_by"]:
+        if media_info.quarantined_by:
             logger.info("Media is quarantined")
             respond_404(request)
             return
@@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet):
                 file_info = FileInfo(
                     server_name=None,
                     file_id=media_id,
-                    url_cache=media_info["url_cache"],
+                    url_cache=bool(media_info.url_cache),
                     thumbnail=info,
                 )
 
@@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet):
             desired_height,
             desired_method,
             desired_type,
-            url_cache=bool(media_info["url_cache"]),
+            url_cache=bool(media_info.url_cache),
         )
 
         if file_path:
@@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet):
             server_name, media_id
         )
 
-        file_id = media_info["filesystem_id"]
+        file_id = media_info.filesystem_id
 
         for info in thumbnail_infos:
             t_w = info.width == desired_width
@@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet):
             if t_w and t_h and t_method and t_type:
                 file_info = FileInfo(
                     server_name=server_name,
-                    file_id=media_info["filesystem_id"],
+                    file_id=file_id,
                     thumbnail=info,
                 )
 
@@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet):
             m_type,
             thumbnail_infos,
             media_id,
-            media_info["filesystem_id"],
+            media_info.filesystem_id,
             url_cache=False,
             server_name=server_name,
         )
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index c8d7c9fd32..7f99c64f1b 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -15,9 +15,7 @@
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
-    Any,
     Collection,
-    Dict,
     Iterable,
     List,
     Optional,
@@ -54,11 +52,32 @@ class LocalMedia:
     media_length: int
     upload_name: str
     created_ts: int
+    url_cache: Optional[str]
     last_access_ts: int
     quarantined_by: Optional[str]
     safe_from_quarantine: bool
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RemoteMedia:
+    media_origin: str
+    media_id: str
+    media_type: str
+    media_length: int
+    upload_name: Optional[str]
+    filesystem_id: str
+    created_ts: int
+    last_access_ts: int
+    quarantined_by: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UrlCache:
+    response_code: int
+    expires_ts: int
+    og: Union[str, bytes]
+
+
 class MediaSortOrder(Enum):
     """
     Enum to define the sorting method used when returning media with
@@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         super().__init__(database, db_conn, hs)
         self.server_name: str = hs.hostname
 
-    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
+    async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
         """Get the metadata for a local piece of media
 
         Returns:
             None if the media_id doesn't exist.
         """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -181,11 +200,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "created_ts",
                 "quarantined_by",
                 "url_cache",
+                "last_access_ts",
                 "safe_from_quarantine",
             ),
             allow_none=True,
             desc="get_local_media",
         )
+        if row is None:
+            return None
+        return LocalMedia(media_id=media_id, **row)
 
     async def get_local_media_by_user_paginate(
         self,
@@ -236,6 +259,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     media_length,
                     upload_name,
                     created_ts,
+                    url_cache,
                     last_access_ts,
                     quarantined_by,
                     safe_from_quarantine
@@ -257,9 +281,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     media_length=row[2],
                     upload_name=row[3],
                     created_ts=row[4],
-                    last_access_ts=row[5],
-                    quarantined_by=row[6],
-                    safe_from_quarantine=bool(row[7]),
+                    url_cache=row[5],
+                    last_access_ts=row[6],
+                    quarantined_by=row[7],
+                    safe_from_quarantine=bool(row[8]),
                 )
                 for row in txn
             ]
@@ -390,51 +415,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )
 
-    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
+    async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
         """
 
-        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
             # get the most recently cached result (relative to the given ts)
-            sql = (
-                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                " FROM local_media_repository_url_cache"
-                " WHERE url = ? AND download_ts <= ?"
-                " ORDER BY download_ts DESC LIMIT 1"
-            )
+            sql = """
+                SELECT response_code, expires_ts, og
+                FROM local_media_repository_url_cache
+                WHERE url = ? AND download_ts <= ?
+                ORDER BY download_ts DESC LIMIT 1
+            """
             txn.execute(sql, (url, ts))
             row = txn.fetchone()
 
             if not row:
                 # ...or if we've requested a timestamp older than the oldest
                 # copy in the cache, return the oldest copy (if any)
-                sql = (
-                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                    " FROM local_media_repository_url_cache"
-                    " WHERE url = ? AND download_ts > ?"
-                    " ORDER BY download_ts ASC LIMIT 1"
-                )
+                sql = """
+                    SELECT response_code, expires_ts, og
+                    FROM local_media_repository_url_cache
+                    WHERE url = ? AND download_ts > ?
+                    ORDER BY download_ts ASC LIMIT 1
+                """
                 txn.execute(sql, (url, ts))
                 row = txn.fetchone()
 
             if not row:
                 return None
 
-            return dict(
-                zip(
-                    (
-                        "response_code",
-                        "etag",
-                        "expires_ts",
-                        "og",
-                        "media_id",
-                        "download_ts",
-                    ),
-                    row,
-                )
-            )
+            return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
 
         return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
@@ -444,7 +457,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         response_code: int,
         etag: Optional[str],
         expires_ts: int,
-        og: Optional[str],
+        og: str,
         media_id: str,
         download_ts: int,
     ) -> None:
@@ -510,8 +523,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_cached_remote_media(
         self, origin: str, media_id: str
-    ) -> Optional[Dict[str, Any]]:
-        return await self.db_pool.simple_select_one(
+    ) -> Optional[RemoteMedia]:
+        row = await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -520,11 +533,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "upload_name",
                 "created_ts",
                 "filesystem_id",
+                "last_access_ts",
                 "quarantined_by",
             ),
             allow_none=True,
             desc="get_cached_remote_media",
         )
+        if row is None:
+            return row
+        return RemoteMedia(media_origin=origin, media_id=media_id, **row)
 
     async def store_cached_remote_media(
         self,
@@ -623,10 +640,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         t_width: int,
         t_height: int,
         t_type: str,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[ThumbnailInfo]:
         """Fetch the thumbnail info of given width, height and type."""
 
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="remote_media_cache_thumbnails",
             keyvalues={
                 "media_origin": origin,
@@ -641,11 +658,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "thumbnail_method",
                 "thumbnail_type",
                 "thumbnail_length",
-                "filesystem_id",
             ),
             allow_none=True,
             desc="get_remote_media_thumbnail",
         )
+        if row is None:
+            return None
+        return ThumbnailInfo(
+            width=row["thumbnail_width"],
+            height=row["thumbnail_height"],
+            method=row["thumbnail_method"],
+            type=row["thumbnail_type"],
+            length=row["thumbnail_length"],
+        )
 
     @trace
     async def store_remote_media_thumbnail(
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 15f5d644e4..a8e7a76b29 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         origin, media_id = self.media_id.split("/")
         info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
         assert info is not None
-        file_id = info["filesystem_id"]
+        file_id = info.filesystem_id
 
         thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
             origin, file_id
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 278808abb5..dac79bd745 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
         # quarantining
         channel = self.make_request(
@@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["quarantined_by"])
+        self.assertTrue(media_info.quarantined_by)
 
         # remove from quarantine
         channel = self.make_request(
@@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
     def test_quarantine_protected_media(self) -> None:
         """
@@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
         # verify protection
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["safe_from_quarantine"])
+        self.assertTrue(media_info.safe_from_quarantine)
 
         # quarantining
         channel = self.make_request(
@@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
         # verify that is not in quarantine
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
 
 class ProtectMediaByIDTestCase(_AdminMediaTests):
@@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["safe_from_quarantine"])
+        self.assertFalse(media_info.safe_from_quarantine)
 
         # protect
         channel = self.make_request(
@@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["safe_from_quarantine"])
+        self.assertTrue(media_info.safe_from_quarantine)
 
         # unprotect
         channel = self.make_request(
@@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["safe_from_quarantine"])
+        self.assertFalse(media_info.safe_from_quarantine)
 
 
 class PurgeMediaCacheTestCase(_AdminMediaTests):
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index b59d9dfd4d..27a663a23b 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
         def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
             """Given an MXC URI, assert whether it has been purged or not."""
             if mxc_uri.server_name == self.hs.config.server.server_name:
-                found_media_dict = self.get_success(
-                    self.store.get_local_media(mxc_uri.media_id)
+                found_media = bool(
+                    self.get_success(self.store.get_local_media(mxc_uri.media_id))
                 )
             else:
-                found_media_dict = self.get_success(
-                    self.store.get_cached_remote_media(
-                        mxc_uri.server_name, mxc_uri.media_id
+                found_media = bool(
+                    self.get_success(
+                        self.store.get_cached_remote_media(
+                            mxc_uri.server_name, mxc_uri.media_id
+                        )
                     )
                 )
 
             if expect_purged:
-                self.assertIsNone(
-                    found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
-                )
+                self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
             else:
-                self.assertIsNotNone(
-                    found_media_dict,
+                self.assertTrue(
+                    found_media,
                     msg=f"{mxc_uri} unexpectedly purged",
                 )