summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/11311.misc1
-rw-r--r--mypy.ini1
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py8
-rw-r--r--synapse/storage/databases/main/media_repository.py141
4 files changed, 89 insertions, 62 deletions
diff --git a/changelog.d/11311.misc b/changelog.d/11311.misc
new file mode 100644
index 0000000000..86594a332d
--- /dev/null
+++ b/changelog.d/11311.misc
@@ -0,0 +1 @@
+Add type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index d81e964dc7..56a62bb9b7 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -36,7 +36,6 @@ exclude = (?x)
    |synapse/storage/databases/main/events_bg_updates.py
    |synapse/storage/databases/main/events_worker.py
    |synapse/storage/databases/main/group_server.py
-   |synapse/storage/databases/main/media_repository.py
    |synapse/storage/databases/main/metrics.py
    |synapse/storage/databases/main/monthly_active_users.py
    |synapse/storage/databases/main/presence.py
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 8ca97b5b18..054f3c296d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -231,7 +231,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         og = await make_deferred_yieldable(observable.observe())
         respond_with_json_bytes(request, 200, og, send_cors=True)
 
-    async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
+    async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
         """Check the db, and download the URL and build a preview
 
         Args:
@@ -360,7 +360,7 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         return jsonog.encode("utf8")
 
-    async def _download_url(self, url: str, user: str) -> MediaInfo:
+    async def _download_url(self, url: str, user: UserID) -> MediaInfo:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         )
 
     async def _precache_image_url(
-        self, user: str, media_info: MediaInfo, og: JsonDict
+        self, user: UserID, media_info: MediaInfo, og: JsonDict
     ) -> None:
         """
         Pre-cache the image (if one exists) for posterity
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 717487be28..1b076683f7 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,10 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -46,7 +61,12 @@ class MediaSortOrder(Enum):
 
 
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -102,13 +122,15 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             self._drop_media_index_without_method,
         )
 
-    async def _drop_media_index_without_method(self, progress, batch_size):
+    async def _drop_media_index_without_method(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """background update handler which removes the old constraints.
 
         Note that this is only run on postgres.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             txn.execute(
                 "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
             )
@@ -126,7 +148,12 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.server_name = hs.hostname
 
@@ -174,7 +201,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             plus the total count of all the user's media
         """
 
-        def get_local_media_by_user_paginate_txn(txn):
+        def get_local_media_by_user_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
 
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
@@ -184,14 +213,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             else:
                 order = "ASC"
 
-            args = [user_id]
+            args: List[Union[str, int]] = [user_id]
             sql = """
                 SELECT COUNT(*) as total_media
                 FROM local_media_repository
                 WHERE user_id = ?
             """
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = txn.fetchone()[0]  # type: ignore[index]
 
             sql = """
                 SELECT
@@ -268,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             )
             sql += sql_keep
 
-        def _get_local_media_before_txn(txn):
+        def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (before_ts, before_ts, size_gt))
             return [row[0] for row in txn]
 
@@ -278,13 +307,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_local_media(
         self,
-        media_id,
-        media_type,
-        time_now_ms,
-        upload_name,
-        media_length,
-        user_id,
-        url_cache=None,
+        media_id: str,
+        media_type: str,
+        time_now_ms: int,
+        upload_name: Optional[str],
+        media_length: int,
+        user_id: UserID,
+        url_cache: Optional[str] = None,
     ) -> None:
         await self.db_pool.simple_insert(
             "local_media_repository",
@@ -315,7 +344,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             None if the URL isn't cached.
         """
 
-        def get_url_cache_txn(txn):
+        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
             # get the most recently cached result (relative to the given ts)
             sql = (
                 "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
@@ -359,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
-    ):
+    ) -> None:
         await self.db_pool.simple_insert(
             "local_media_repository_url_cache",
             {
@@ -390,13 +419,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_local_thumbnail(
         self,
-        media_id,
-        thumbnail_width,
-        thumbnail_height,
-        thumbnail_type,
-        thumbnail_method,
-        thumbnail_length,
-    ):
+        media_id: str,
+        thumbnail_width: int,
+        thumbnail_height: int,
+        thumbnail_type: str,
+        thumbnail_method: str,
+        thumbnail_length: int,
+    ) -> None:
         await self.db_pool.simple_upsert(
             table="local_media_repository_thumbnails",
             keyvalues={
@@ -430,14 +459,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_cached_remote_media(
         self,
-        origin,
-        media_id,
-        media_type,
-        media_length,
-        time_now_ms,
-        upload_name,
-        filesystem_id,
-    ):
+        origin: str,
+        media_id: str,
+        media_type: str,
+        media_length: int,
+        time_now_ms: int,
+        upload_name: Optional[str],
+        filesystem_id: str,
+    ) -> None:
         await self.db_pool.simple_insert(
             "remote_media_cache",
             {
@@ -458,7 +487,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         local_media: Iterable[str],
         remote_media: Iterable[Tuple[str, str]],
         time_ms: int,
-    ):
+    ) -> None:
         """Updates the last access time of the given media
 
         Args:
@@ -467,7 +496,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             time_ms: Current time in milliseconds
         """
 
-        def update_cache_txn(txn):
+        def update_cache_txn(txn: LoggingTransaction) -> None:
             sql = (
                 "UPDATE remote_media_cache SET last_access_ts = ?"
                 " WHERE media_origin = ? AND media_id = ?"
@@ -488,7 +517,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
         )
 
@@ -542,15 +571,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_remote_media_thumbnail(
         self,
-        origin,
-        media_id,
-        filesystem_id,
-        thumbnail_width,
-        thumbnail_height,
-        thumbnail_type,
-        thumbnail_method,
-        thumbnail_length,
-    ):
+        origin: str,
+        media_id: str,
+        filesystem_id: str,
+        thumbnail_width: int,
+        thumbnail_height: int,
+        thumbnail_type: str,
+        thumbnail_method: str,
+        thumbnail_length: int,
+    ) -> None:
         await self.db_pool.simple_upsert(
             table="remote_media_cache_thumbnails",
             keyvalues={
@@ -566,7 +595,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_remote_media_thumbnail",
         )
 
-    async def get_remote_media_before(self, before_ts):
+    async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
         sql = (
             "SELECT media_origin, media_id, filesystem_id"
             " FROM remote_media_cache"
@@ -602,7 +631,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             " LIMIT 500"
         )
 
-        def _get_expired_url_cache_txn(txn):
+        def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]
 
@@ -610,18 +639,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_expired_url_cache", _get_expired_url_cache_txn
         )
 
-    async def delete_url_cache(self, media_ids):
+    async def delete_url_cache(self, media_ids: Collection[str]) -> None:
         if len(media_ids) == 0:
             return
 
         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
-        def _delete_url_cache_txn(txn):
+        def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
             txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
-        return await self.db_pool.runInteraction(
-            "delete_url_cache", _delete_url_cache_txn
-        )
+        await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)
 
     async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
         sql = (
@@ -631,7 +658,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             " LIMIT 500"
         )
 
-        def _get_url_cache_media_before_txn(txn):
+        def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]
 
@@ -639,11 +666,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )
 
-    async def delete_url_cache_media(self, media_ids):
+    async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
         if len(media_ids) == 0:
             return
 
-        def _delete_url_cache_media_txn(txn):
+        def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
             txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
@@ -652,6 +679,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
         )