diff --git a/changelog.d/16611.misc b/changelog.d/16611.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16611.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c2109036ec..1027fbfd28 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
from synapse.api.errors import (
AuthError,
@@ -23,6 +23,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -306,7 +307,9 @@ class ProfileHandler:
server_name = host
if self._is_mine_server_name(server_name):
- media_info = await self.store.get_local_media(media_id)
+ media_info: Optional[
+ Union[LocalMedia, RemoteMedia]
+ ] = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -322,12 +325,12 @@ class ProfileHandler:
if self.max_avatar_size:
# Ensure avatar does not exceed max allowed avatar size
- if media_info["media_length"] > self.max_avatar_size:
+ if media_info.media_length > self.max_avatar_size:
logger.warning(
"Forbidding avatar change to %s: %d bytes is above the allowed size "
"limit",
mxc,
- media_info["media_length"],
+ media_info.media_length,
)
return False
@@ -335,12 +338,12 @@ class ProfileHandler:
# Ensure the avatar's file type is allowed
if (
self.allowed_avatar_mimetypes
- and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ and media_info.media_type not in self.allowed_avatar_mimetypes
):
logger.warning(
"Forbidding avatar change to %s: mimetype %s not allowed",
mxc,
- media_info["media_type"],
+ media_info.media_type,
)
return False
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 62f2454f5d..389dc5298a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -806,7 +806,7 @@ class SsoHandler:
media_id = profile["avatar_url"].split("/")[-1]
if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id)
- if media is not None and upload_name == media["upload_name"]:
+ if media is not None and upload_name == media.upload_name:
logger.info("skipping saving the user avatar")
return True
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 72b0f1c5de..1957426c6a 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -19,6 +19,7 @@ import shutil
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+import attr
from matrix_common.types.mxc_uri import MXCUri
import twisted.internet.error
@@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.media_repository import RemoteMedia
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@@ -245,18 +247,18 @@ class MediaRepository:
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info or media_info.quarantined_by:
respond_404(request)
return
self.mark_recently_accessed(None, media_id)
- media_type = media_info["media_type"]
+ media_type = media_info.media_type
if not media_type:
media_type = "application/octet-stream"
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
- url_cache = media_info["url_cache"]
+ media_length = media_info.media_length
+ upload_name = name if name else media_info.upload_name
+ url_cache = media_info.url_cache
file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
@@ -310,16 +312,20 @@ class MediaRepository:
# We deliberately stream the file outside the lock
if responder:
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
+ upload_name = name if name else media_info.upload_name
await respond_with_responder(
- request, responder, media_type, media_length, upload_name
+ request,
+ responder,
+ media_info.media_type,
+ media_info.media_length,
+ upload_name,
)
else:
respond_404(request)
- async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+ async def get_remote_media_info(
+ self, server_name: str, media_id: str
+ ) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -353,7 +359,7 @@ class MediaRepository:
async def _get_remote_media_impl(
self, server_name: str, media_id: str
- ) -> Tuple[Optional[Responder], dict]:
+ ) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -373,15 +379,17 @@ class MediaRepository:
# If we have an entry in the DB, try and look for it
if media_info:
- file_id = media_info["filesystem_id"]
+ file_id = media_info.filesystem_id
file_info = FileInfo(server_name, file_id)
- if media_info["quarantined_by"]:
+ if media_info.quarantined_by:
logger.info("Media is quarantined")
raise NotFoundError()
- if not media_info["media_type"]:
- media_info["media_type"] = "application/octet-stream"
+ if not media_info.media_type:
+ media_info = attr.evolve(
+ media_info, media_type="application/octet-stream"
+ )
responder = await self.media_storage.fetch_media(file_info)
if responder:
@@ -403,9 +411,9 @@ class MediaRepository:
if not media_info:
raise e
- file_id = media_info["filesystem_id"]
- if not media_info["media_type"]:
- media_info["media_type"] = "application/octet-stream"
+ file_id = media_info.filesystem_id
+ if not media_info.media_type:
+ media_info = attr.evolve(media_info, media_type="application/octet-stream")
file_info = FileInfo(server_name, file_id)
# We generate thumbnails even if another process downloaded the media
@@ -415,7 +423,7 @@ class MediaRepository:
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
- server_name, media_id, file_id, media_info["media_type"]
+ server_name, media_id, file_id, media_info.media_type
)
responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +433,7 @@ class MediaRepository:
self,
server_name: str,
media_id: str,
- ) -> dict:
+ ) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -518,7 +526,7 @@ class MediaRepository:
origin=server_name,
media_id=media_id,
media_type=media_type,
- time_now_ms=self.clock.time_msec(),
+ time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
@@ -526,15 +534,17 @@ class MediaRepository:
logger.info("Stored remote media in file %r", fname)
- media_info = {
- "media_type": media_type,
- "media_length": length,
- "upload_name": upload_name,
- "created_ts": time_now_ms,
- "filesystem_id": file_id,
- }
-
- return media_info
+ return RemoteMedia(
+ media_origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ media_length=length,
+ upload_name=upload_name,
+ created_ts=time_now_ms,
+ filesystem_id=file_id,
+ last_access_ts=time_now_ms,
+ quarantined_by=None,
+ )
def _get_thumbnail_requirements(
self, media_type: str
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 9b5a3dd5f4..44aac21de6 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -240,15 +240,14 @@ class UrlPreviewer:
cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
- and cache_result["expires_ts"] > ts
- and cache_result["response_code"] / 100 == 2
+ and cache_result.expires_ts > ts
+ and cache_result.response_code // 100 == 2
):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
- og = cache_result["og"]
- if isinstance(og, str):
- og = og.encode("utf8")
- return og
+ if isinstance(cache_result.og, str):
+ return cache_result.og.encode("utf8")
+ return cache_result.og
# If this URL can be accessed via an allowed oEmbed, use that instead.
url_to_download = url
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 85b6bdbe72..efda8b4ab4 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet):
if not media_info:
respond_404(request)
return
- if media_info["quarantined_by"]:
+ if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
@@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet):
thumbnail_infos,
media_id,
media_id,
- url_cache=bool(media_info["url_cache"]),
+ url_cache=bool(media_info.url_cache),
server_name=None,
)
@@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet):
if not media_info:
respond_404(request)
return
- if media_info["quarantined_by"]:
+ if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
@@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet):
file_info = FileInfo(
server_name=None,
file_id=media_id,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info.url_cache),
thumbnail=info,
)
@@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet):
desired_height,
desired_method,
desired_type,
- url_cache=bool(media_info["url_cache"]),
+ url_cache=bool(media_info.url_cache),
)
if file_path:
@@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet):
server_name, media_id
)
- file_id = media_info["filesystem_id"]
+ file_id = media_info.filesystem_id
for info in thumbnail_infos:
t_w = info.width == desired_width
@@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet):
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
- file_id=media_info["filesystem_id"],
+ file_id=file_id,
thumbnail=info,
)
@@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet):
m_type,
thumbnail_infos,
media_id,
- media_info["filesystem_id"],
+ media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index c8d7c9fd32..7f99c64f1b 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -15,9 +15,7 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
- Any,
Collection,
- Dict,
Iterable,
List,
Optional,
@@ -54,11 +52,32 @@ class LocalMedia:
media_length: int
upload_name: str
created_ts: int
+ url_cache: Optional[str]
last_access_ts: int
quarantined_by: Optional[str]
safe_from_quarantine: bool
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RemoteMedia:
+ media_origin: str
+ media_id: str
+ media_type: str
+ media_length: int
+ upload_name: Optional[str]
+ filesystem_id: str
+ created_ts: int
+ last_access_ts: int
+ quarantined_by: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UrlCache:
+ response_code: int
+ expires_ts: int
+ og: Union[str, bytes]
+
+
class MediaSortOrder(Enum):
"""
Enum to define the sorting method used when returning media with
@@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
super().__init__(database, db_conn, hs)
self.server_name: str = hs.hostname
- async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
+ async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
- return await self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -181,11 +200,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"created_ts",
"quarantined_by",
"url_cache",
+ "last_access_ts",
"safe_from_quarantine",
),
allow_none=True,
desc="get_local_media",
)
+ if row is None:
+ return None
+ return LocalMedia(media_id=media_id, **row)
async def get_local_media_by_user_paginate(
self,
@@ -236,6 +259,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length,
upload_name,
created_ts,
+ url_cache,
last_access_ts,
quarantined_by,
safe_from_quarantine
@@ -257,9 +281,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length=row[2],
upload_name=row[3],
created_ts=row[4],
- last_access_ts=row[5],
- quarantined_by=row[6],
- safe_from_quarantine=bool(row[7]),
+ url_cache=row[5],
+ last_access_ts=row[6],
+ quarantined_by=row[7],
+ safe_from_quarantine=bool(row[8]),
)
for row in txn
]
@@ -390,51 +415,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
- async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
+ async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
"""
- def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+ def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
# get the most recently cached result (relative to the given ts)
- sql = (
- "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
- " FROM local_media_repository_url_cache"
- " WHERE url = ? AND download_ts <= ?"
- " ORDER BY download_ts DESC LIMIT 1"
- )
+ sql = """
+ SELECT response_code, expires_ts, og
+ FROM local_media_repository_url_cache
+ WHERE url = ? AND download_ts <= ?
+ ORDER BY download_ts DESC LIMIT 1
+ """
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
- sql = (
- "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
- " FROM local_media_repository_url_cache"
- " WHERE url = ? AND download_ts > ?"
- " ORDER BY download_ts ASC LIMIT 1"
- )
+ sql = """
+ SELECT response_code, expires_ts, og
+ FROM local_media_repository_url_cache
+ WHERE url = ? AND download_ts > ?
+ ORDER BY download_ts ASC LIMIT 1
+ """
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
return None
- return dict(
- zip(
- (
- "response_code",
- "etag",
- "expires_ts",
- "og",
- "media_id",
- "download_ts",
- ),
- row,
- )
- )
+ return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
@@ -444,7 +457,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
response_code: int,
etag: Optional[str],
expires_ts: int,
- og: Optional[str],
+ og: str,
media_id: str,
download_ts: int,
) -> None:
@@ -510,8 +523,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_cached_remote_media(
self, origin: str, media_id: str
- ) -> Optional[Dict[str, Any]]:
- return await self.db_pool.simple_select_one(
+ ) -> Optional[RemoteMedia]:
+ row = await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -520,11 +533,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"upload_name",
"created_ts",
"filesystem_id",
+ "last_access_ts",
"quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
)
+ if row is None:
+ return row
+ return RemoteMedia(media_origin=origin, media_id=media_id, **row)
async def store_cached_remote_media(
self,
@@ -623,10 +640,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
t_width: int,
t_height: int,
t_type: str,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThumbnailInfo]:
"""Fetch the thumbnail info of given width, height and type."""
- return await self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",
keyvalues={
"media_origin": origin,
@@ -641,11 +658,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
- "filesystem_id",
),
allow_none=True,
desc="get_remote_media_thumbnail",
)
+ if row is None:
+ return None
+ return ThumbnailInfo(
+ width=row["thumbnail_width"],
+ height=row["thumbnail_height"],
+ method=row["thumbnail_method"],
+ type=row["thumbnail_type"],
+ length=row["thumbnail_length"],
+ )
@trace
async def store_remote_media_thumbnail(
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 15f5d644e4..a8e7a76b29 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
origin, media_id = self.media_id.split("/")
info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
assert info is not None
- file_id = info["filesystem_id"]
+ file_id = info.filesystem_id
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
origin, file_id
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 278808abb5..dac79bd745 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["quarantined_by"])
+ self.assertFalse(media_info.quarantined_by)
# quarantining
channel = self.make_request(
@@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertTrue(media_info["quarantined_by"])
+ self.assertTrue(media_info.quarantined_by)
# remove from quarantine
channel = self.make_request(
@@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["quarantined_by"])
+ self.assertFalse(media_info.quarantined_by)
def test_quarantine_protected_media(self) -> None:
"""
@@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertTrue(media_info["safe_from_quarantine"])
+ self.assertTrue(media_info.safe_from_quarantine)
# quarantining
channel = self.make_request(
@@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["quarantined_by"])
+ self.assertFalse(media_info.quarantined_by)
class ProtectMediaByIDTestCase(_AdminMediaTests):
@@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["safe_from_quarantine"])
+ self.assertFalse(media_info.safe_from_quarantine)
# protect
channel = self.make_request(
@@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertTrue(media_info["safe_from_quarantine"])
+ self.assertTrue(media_info.safe_from_quarantine)
# unprotect
channel = self.make_request(
@@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["safe_from_quarantine"])
+ self.assertFalse(media_info.safe_from_quarantine)
class PurgeMediaCacheTestCase(_AdminMediaTests):
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index b59d9dfd4d..27a663a23b 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
"""Given an MXC URI, assert whether it has been purged or not."""
if mxc_uri.server_name == self.hs.config.server.server_name:
- found_media_dict = self.get_success(
- self.store.get_local_media(mxc_uri.media_id)
+ found_media = bool(
+ self.get_success(self.store.get_local_media(mxc_uri.media_id))
)
else:
- found_media_dict = self.get_success(
- self.store.get_cached_remote_media(
- mxc_uri.server_name, mxc_uri.media_id
+ found_media = bool(
+ self.get_success(
+ self.store.get_cached_remote_media(
+ mxc_uri.server_name, mxc_uri.media_id
+ )
)
)
if expect_purged:
- self.assertIsNone(
- found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
- )
+ self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
else:
- self.assertIsNotNone(
- found_media_dict,
+ self.assertTrue(
+ found_media,
msg=f"{mxc_uri} unexpectedly purged",
)
|