summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-11-12 11:05:26 -0500
committerGitHub <noreply@github.com>2021-11-12 11:05:26 -0500
commit9b90b9454b8855e0575785560662d8e47378094d (patch)
tree25a3f317828af7ea505372110390d5c5f1ab5a77 /synapse/storage/databases/main
parentAttempt to annotate events_forward_extremities (#11314) (diff)
downloadsynapse-9b90b9454b8855e0575785560662d8e47378094d.tar.xz
Add type hints to media repository storage module (#11311)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/media_repository.py141
1 files changed, 84 insertions, 57 deletions
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 717487be28..1b076683f7 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,10 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -46,7 +61,12 @@ class MediaSortOrder(Enum):
 
 
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -102,13 +122,15 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             self._drop_media_index_without_method,
         )
 
-    async def _drop_media_index_without_method(self, progress, batch_size):
+    async def _drop_media_index_without_method(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """background update handler which removes the old constraints.
 
         Note that this is only run on postgres.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             txn.execute(
                 "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
             )
@@ -126,7 +148,12 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.server_name = hs.hostname
 
@@ -174,7 +201,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             plus the total count of all the user's media
         """
 
-        def get_local_media_by_user_paginate_txn(txn):
+        def get_local_media_by_user_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
 
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
@@ -184,14 +213,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             else:
                 order = "ASC"
 
-            args = [user_id]
+            args: List[Union[str, int]] = [user_id]
             sql = """
                 SELECT COUNT(*) as total_media
                 FROM local_media_repository
                 WHERE user_id = ?
             """
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = txn.fetchone()[0]  # type: ignore[index]
 
             sql = """
                 SELECT
@@ -268,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             )
             sql += sql_keep
 
-        def _get_local_media_before_txn(txn):
+        def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (before_ts, before_ts, size_gt))
             return [row[0] for row in txn]
 
@@ -278,13 +307,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_local_media(
         self,
-        media_id,
-        media_type,
-        time_now_ms,
-        upload_name,
-        media_length,
-        user_id,
-        url_cache=None,
+        media_id: str,
+        media_type: str,
+        time_now_ms: int,
+        upload_name: Optional[str],
+        media_length: int,
+        user_id: UserID,
+        url_cache: Optional[str] = None,
     ) -> None:
         await self.db_pool.simple_insert(
             "local_media_repository",
@@ -315,7 +344,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             None if the URL isn't cached.
         """
 
-        def get_url_cache_txn(txn):
+        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
             # get the most recently cached result (relative to the given ts)
             sql = (
                 "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
@@ -359,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
-    ):
+    ) -> None:
         await self.db_pool.simple_insert(
             "local_media_repository_url_cache",
             {
@@ -390,13 +419,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_local_thumbnail(
         self,
-        media_id,
-        thumbnail_width,
-        thumbnail_height,
-        thumbnail_type,
-        thumbnail_method,
-        thumbnail_length,
-    ):
+        media_id: str,
+        thumbnail_width: int,
+        thumbnail_height: int,
+        thumbnail_type: str,
+        thumbnail_method: str,
+        thumbnail_length: int,
+    ) -> None:
         await self.db_pool.simple_upsert(
             table="local_media_repository_thumbnails",
             keyvalues={
@@ -430,14 +459,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_cached_remote_media(
         self,
-        origin,
-        media_id,
-        media_type,
-        media_length,
-        time_now_ms,
-        upload_name,
-        filesystem_id,
-    ):
+        origin: str,
+        media_id: str,
+        media_type: str,
+        media_length: int,
+        time_now_ms: int,
+        upload_name: Optional[str],
+        filesystem_id: str,
+    ) -> None:
         await self.db_pool.simple_insert(
             "remote_media_cache",
             {
@@ -458,7 +487,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         local_media: Iterable[str],
         remote_media: Iterable[Tuple[str, str]],
         time_ms: int,
-    ):
+    ) -> None:
         """Updates the last access time of the given media
 
         Args:
@@ -467,7 +496,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             time_ms: Current time in milliseconds
         """
 
-        def update_cache_txn(txn):
+        def update_cache_txn(txn: LoggingTransaction) -> None:
             sql = (
                 "UPDATE remote_media_cache SET last_access_ts = ?"
                 " WHERE media_origin = ? AND media_id = ?"
@@ -488,7 +517,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
         )
 
@@ -542,15 +571,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def store_remote_media_thumbnail(
         self,
-        origin,
-        media_id,
-        filesystem_id,
-        thumbnail_width,
-        thumbnail_height,
-        thumbnail_type,
-        thumbnail_method,
-        thumbnail_length,
-    ):
+        origin: str,
+        media_id: str,
+        filesystem_id: str,
+        thumbnail_width: int,
+        thumbnail_height: int,
+        thumbnail_type: str,
+        thumbnail_method: str,
+        thumbnail_length: int,
+    ) -> None:
         await self.db_pool.simple_upsert(
             table="remote_media_cache_thumbnails",
             keyvalues={
@@ -566,7 +595,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_remote_media_thumbnail",
         )
 
-    async def get_remote_media_before(self, before_ts):
+    async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
         sql = (
             "SELECT media_origin, media_id, filesystem_id"
             " FROM remote_media_cache"
@@ -602,7 +631,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             " LIMIT 500"
         )
 
-        def _get_expired_url_cache_txn(txn):
+        def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]
 
@@ -610,18 +639,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_expired_url_cache", _get_expired_url_cache_txn
         )
 
-    async def delete_url_cache(self, media_ids):
+    async def delete_url_cache(self, media_ids: Collection[str]) -> None:
         if len(media_ids) == 0:
             return
 
         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
-        def _delete_url_cache_txn(txn):
+        def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
             txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
-        return await self.db_pool.runInteraction(
-            "delete_url_cache", _delete_url_cache_txn
-        )
+        await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)
 
     async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
         sql = (
@@ -631,7 +658,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             " LIMIT 500"
         )
 
-        def _get_url_cache_media_before_txn(txn):
+        def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]
 
@@ -639,11 +666,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )
 
-    async def delete_url_cache_media(self, media_ids):
+    async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
         if len(media_ids) == 0:
             return
 
-        def _delete_url_cache_media_txn(txn):
+        def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
             txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
@@ -652,6 +679,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
         )