summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/client_ips.py4
-rw-r--r--synapse/storage/databases/main/directory.py6
-rw-r--r--synapse/storage/databases/main/filtering.py5
-rw-r--r--synapse/storage/databases/main/openid.py8
-rw-r--r--synapse/storage/databases/main/profile.py6
-rw-r--r--synapse/storage/databases/main/push_rule.py10
-rw-r--r--synapse/storage/databases/main/room.py49
-rw-r--r--synapse/storage/databases/main/signatures.py40
-rw-r--r--synapse/storage/databases/main/ui_auth.py4
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py8
10 files changed, 90 insertions, 50 deletions
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 216a5925fc..c2fc847fbc 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         self._batch_row_update[key] = (user_agent, device_id, now)
 
     @wrap_as_background_process("update_client_ips")
-    def _update_client_ips_batch(self):
+    async def _update_client_ips_batch(self) -> None:
 
         # If the DB pool has already terminated, don't try updating
         if not self.db_pool.is_running():
@@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         to_update = self._batch_row_update
         self._batch_row_update = {}
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 405b5eafa5..e5060d4c46 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
 
         return room_id
 
-    def update_aliases_for_room(
+    async def update_aliases_for_room(
         self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
-    ):
+    ) -> None:
         """Repoint all of the aliases for a given room, to a different room.
 
         Args:
@@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
                 txn, self.get_aliases_for_room, (new_room_id,)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_update_aliases_for_room_txn", _update_aliases_for_room_txn
         )
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 45a1760170..d2f5b9a502 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-    def add_user_filter(self, user_localpart, user_filter):
+    async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
         def_json = encode_canonical_json(user_filter)
 
         # Need an atomic transaction to SELECT the maximal ID so far then
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
 
             return filter_id
 
-        return self.db_pool.runInteraction("add_user_filter", _do_txn)
+        return await self.db_pool.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index 4db8949da7..2aac64901b 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from synapse.storage._base import SQLBaseStore
 
 
@@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
             desc="insert_open_id_token",
         )
 
-    def get_user_id_for_open_id_token(self, token, ts_now_ms):
+    async def get_user_id_for_open_id_token(
+        self, token: str, ts_now_ms: int
+    ) -> Optional[str]:
         def get_user_id_for_token_txn(txn):
             sql = (
                 "SELECT user_id FROM open_id_tokens"
@@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
             else:
                 return rows[0][0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_user_id_for_token", get_user_id_for_token_txn
         )
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 301875a672..d2e0685e9e 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
                 desc="delete_remote_profile_cache",
             )
 
-    def get_remote_profile_cache_entries_that_expire(self, last_checked):
+    async def get_remote_profile_cache_entries_that_expire(
+        self, last_checked: int
+    ) -> Dict[str, str]:
         """Get all users who haven't been checked since `last_checked`
         """
 
@@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
 
             return self.db_pool.cursor_to_dict(txn)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_remote_profile_cache_entries_that_expire",
             _get_remote_profile_cache_entries_that_expire_txn,
         )
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2fb5b02d7d..0de802a86b 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -18,8 +18,6 @@ import abc
 import logging
 from typing import List, Tuple, Union
 
-from twisted.internet import defer
-
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -149,9 +147,11 @@ class PushRulesWorkerStore(
         )
         return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
+    async def have_push_rules_changed_for_user(
+        self, user_id: str, last_id: int
+    ) -> bool:
         if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
+            return False
         else:
 
             def have_push_rules_changed_txn(txn):
@@ -163,7 +163,7 @@ class PushRulesWorkerStore(
                 (count,) = txn.fetchone()
                 return bool(count)
 
-            return self.db_pool.runInteraction(
+            return await self.db_pool.runInteraction(
                 "have_push_rules_changed", have_push_rules_changed_txn
             )
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a92641c339..717df97301 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    def get_room_with_stats(self, room_id: str):
+    async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
         """Retrieve room with statistics.
 
         Args:
@@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
             res["public"] = bool(res["public"])
             return res
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_room_with_stats", get_room_with_stats_txn, room_id
         )
 
@@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
             desc="get_public_room_ids",
         )
 
-    def count_public_rooms(self, network_tuple, ignore_non_federatable):
+    async def count_public_rooms(
+        self,
+        network_tuple: Optional[ThirdPartyInstanceID],
+        ignore_non_federatable: bool,
+    ) -> int:
         """Counts the number of public rooms as tracked in the room_stats_current
         and room_stats_state table.
 
         Args:
-            network_tuple (ThirdPartyInstanceID|None)
-            ignore_non_federatable (bool): If true filters out non-federatable rooms
+            network_tuple
+            ignore_non_federatable: If true filters out non-federatable rooms
         """
 
         def _count_public_rooms_txn(txn):
@@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
             txn.execute(sql, query_args)
             return txn.fetchone()[0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_public_rooms", _count_public_rooms_txn
         )
 
@@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
 
         return row
 
-    def get_media_mxcs_in_room(self, room_id):
+    async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
 
         Args:
-            room_id (str)
+            room_id
 
         Returns:
-            The local and remote media as a lists of tuples where the key is
-            the hostname and the value is the media ID.
+            The local and remote media as a lists of the media IDs.
         """
 
         def _get_media_mxcs_in_room_txn(txn):
@@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
 
             return local_media_mxcs, remote_media_mxcs
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_media_ids_in_room", _get_media_mxcs_in_room_txn
         )
 
-    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+    async def quarantine_media_ids_in_room(
+        self, room_id: str, quarantined_by: str
+    ) -> int:
         """For a room loops through all events with media and quarantines
         the associated media
         """
@@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
                 txn, local_mxcs, remote_mxcs, quarantined_by
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
@@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
 
         return local_media_mxcs, remote_media_mxcs
 
-    def quarantine_media_by_id(
+    async def quarantine_media_by_id(
         self, server_name: str, media_id: str, quarantined_by: str,
-    ):
+    ) -> int:
         """quarantines a single local or remote media id
 
         Args:
@@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
                 txn, local_mxcs, remote_mxcs, quarantined_by
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_by_user", _quarantine_media_by_id_txn
         )
 
-    def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+    async def quarantine_media_ids_by_user(
+        self, user_id: str, quarantined_by: str
+    ) -> int:
         """quarantines all local media associated with a single user
 
         Args:
@@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
             local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
             return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_by_user", _quarantine_media_by_user_txn
         )
 
@@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    def get_room_count(self):
-        """Retrieve a list of all rooms
+    async def get_room_count(self) -> int:
+        """Retrieve the total number of rooms.
         """
 
         def f(txn):
@@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             row = txn.fetchone()
             return row[0] or 0
 
-        return self.db_pool.runInteraction("get_rooms", f)
+        return await self.db_pool.runInteraction("get_rooms", f)
 
     async def add_event_report(
         self,
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index be191dd870..c8c67953e4 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict, Iterable, List, Tuple
+
 from unpaddedbase64 import encode_base64
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
 
 
@@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
     @cachedList(
         cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
     )
-    def get_event_reference_hashes(self, event_ids):
+    async def get_event_reference_hashes(
+        self, event_ids: Iterable[str]
+    ) -> Dict[str, Dict[str, bytes]]:
+        """Get all hashes for given events.
+
+        Args:
+            event_ids: The event IDs to get hashes for.
+
+        Returns:
+             A mapping of event ID to a mapping of algorithm to hash.
+        """
+
         def f(txn):
             return {
                 event_id: self._get_event_reference_hashes_txn(txn, event_id)
                 for event_id in event_ids
             }
 
-        return self.db_pool.runInteraction("get_event_reference_hashes", f)
+        return await self.db_pool.runInteraction("get_event_reference_hashes", f)
 
-    async def add_event_hashes(self, event_ids):
+    async def add_event_hashes(
+        self, event_ids: Iterable[str]
+    ) -> List[Tuple[str, Dict[str, str]]]:
+        """
+
+        Args:
+            event_ids: The event IDs
+
+        Returns:
+            A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
+        """
         hashes = await self.get_event_reference_hashes(event_ids)
         hashes = {
             e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
@@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
 
         return list(hashes.items())
 
-    def _get_event_reference_hashes_txn(self, txn, event_id):
+    def _get_event_reference_hashes_txn(
+        self, txn: Cursor, event_id: str
+    ) -> Dict[str, bytes]:
         """Get all the hashes for a given PDU.
         Args:
-            txn (cursor):
-            event_id (str): Id for the Event.
+            txn:
+            event_id: Id for the Event.
         Returns:
-            A dict[unicode, bytes] of algorithm -> hash.
+            A mapping of algorithm -> hash.
         """
         query = (
             "SELECT algorithm, hash"
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 9eef8e57c5..b89668d561 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
 
 
 class UIAuthStore(UIAuthWorkerStore):
-    def delete_old_ui_auth_sessions(self, expiration_time: int):
+    async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
         """
         Remove sessions which were last used earlier than the expiration time.
 
@@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
                 This is an epoch time in milliseconds.
 
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_old_ui_auth_sessions",
             self._delete_old_ui_auth_sessions_txn,
             expiration_time,
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index e3547e53b3..2f7c95fc74 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
 
 
 class UserErasureStore(UserErasureWorkerStore):
-    def mark_user_erased(self, user_id: str) -> None:
+    async def mark_user_erased(self, user_id: str) -> None:
         """Indicate that user_id wishes their message history to be erased.
 
         Args:
@@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
 
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
-        return self.db_pool.runInteraction("mark_user_erased", f)
+        await self.db_pool.runInteraction("mark_user_erased", f)
 
-    def mark_user_not_erased(self, user_id: str) -> None:
+    async def mark_user_not_erased(self, user_id: str) -> None:
         """Indicate that user_id is no longer erased.
 
         Args:
@@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
 
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
-        return self.db_pool.runInteraction("mark_user_not_erased", f)
+        await self.db_pool.runInteraction("mark_user_not_erased", f)