summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-12 09:28:48 -0400
committerGitHub <noreply@github.com>2020-08-12 09:28:48 -0400
commita3a59bab7bb3b69dcfc5620e6f3ac51af3f0f965 (patch)
treeeeec255478ccc92af2111519681831bca5aad289 /synapse/storage/databases/main
parentFix typing for notifier (#8064) (diff)
downloadsynapse-a3a59bab7bb3b69dcfc5620e6f3ac51af3f0f965.tar.xz
Convert appservice, group server, profile and more databases to async (#8066)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/appservice.py34
-rw-r--r--synapse/storage/databases/main/filtering.py8
-rw-r--r--synapse/storage/databases/main/group_server.py86
-rw-r--r--synapse/storage/databases/main/presence.py7
-rw-r--r--synapse/storage/databases/main/profile.py21
-rw-r--r--synapse/storage/databases/main/relations.py19
-rw-r--r--synapse/storage/databases/main/transactions.py7
7 files changed, 78 insertions, 104 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 055a3962dc..5cf1a88399 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -18,8 +18,6 @@ import re
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
 class ApplicationServiceTransactionWorkerStore(
     ApplicationServiceWorkerStore, EventsWorkerStore
 ):
-    @defer.inlineCallbacks
-    def get_appservices_by_state(self, state):
+    async def get_appservices_by_state(self, state):
         """Get a list of application services based on their state.
 
         Args:
             state(ApplicationServiceState): The state to filter on.
         Returns:
-            A Deferred which resolves to a list of ApplicationServices, which
-            may be empty.
+            A list of ApplicationServices, which may be empty.
         """
-        results = yield self.db_pool.simple_select_list(
+        results = await self.db_pool.simple_select_list(
             "application_services_state", {"state": state}, ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
@@ -147,16 +143,15 @@ class ApplicationServiceTransactionWorkerStore(
                     services.append(service)
         return services
 
-    @defer.inlineCallbacks
-    def get_appservice_state(self, service):
+    async def get_appservice_state(self, service):
         """Get the application service state.
 
         Args:
             service(ApplicationService): The service whose state to set.
         Returns:
-            A Deferred which resolves to ApplicationServiceState.
+            An ApplicationServiceState.
         """
-        result = yield self.db_pool.simple_select_one(
+        result = await self.db_pool.simple_select_one(
             "application_services_state",
             {"as_id": service.id},
             ["state"],
@@ -270,16 +265,14 @@ class ApplicationServiceTransactionWorkerStore(
             "complete_appservice_txn", _complete_appservice_txn
         )
 
-    @defer.inlineCallbacks
-    def get_oldest_unsent_txn(self, service):
+    async def get_oldest_unsent_txn(self, service):
         """Get the oldest transaction which has not been sent for this
         service.
 
         Args:
             service(ApplicationService): The app service to get the oldest txn.
         Returns:
-            A Deferred which resolves to an AppServiceTransaction or
-            None.
+            An AppServiceTransaction or None.
         """
 
         def _get_oldest_unsent_txn(txn):
@@ -298,7 +291,7 @@ class ApplicationServiceTransactionWorkerStore(
 
             return entry
 
-        entry = yield self.db_pool.runInteraction(
+        entry = await self.db_pool.runInteraction(
             "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
         )
 
@@ -307,7 +300,7 @@ class ApplicationServiceTransactionWorkerStore(
 
         event_ids = db_to_json(entry["event_ids"])
 
-        events = yield self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(event_ids)
 
         return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
 
@@ -332,8 +325,7 @@ class ApplicationServiceTransactionWorkerStore(
             "set_appservice_last_pos", set_appservice_last_pos_txn
         )
 
-    @defer.inlineCallbacks
-    def get_new_events_for_appservice(self, current_id, limit):
+    async def get_new_events_for_appservice(self, current_id, limit):
         """Get all new evnets"""
 
         def get_new_events_for_appservice_txn(txn):
@@ -357,11 +349,11 @@ class ApplicationServiceTransactionWorkerStore(
 
             return upper_bound, [row[1] for row in rows]
 
-        upper_bound, event_ids = yield self.db_pool.runInteraction(
+        upper_bound, event_ids = await self.db_pool.runInteraction(
             "get_new_events_for_appservice", get_new_events_for_appservice_txn
         )
 
-        events = yield self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(event_ids)
 
         return upper_bound, events
 
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index cae6bda80e..45a1760170 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,12 +17,12 @@ from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 
 
 class FilteringStore(SQLBaseStore):
-    @cachedInlineCallbacks(num_args=2)
-    def get_user_filter(self, user_localpart, filter_id):
+    @cached(num_args=2)
+    async def get_user_filter(self, user_localpart, filter_id):
         # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
         # with a coherent error message rather than 500 M_UNKNOWN.
         try:
@@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
         except ValueError:
             raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
 
-        def_json = yield self.db_pool.simple_select_one_onecol(
+        def_json = await self.db_pool.simple_select_one_onecol(
             table="user_filters",
             keyvalues={"user_id": user_localpart, "filter_id": filter_id},
             retcol="filter_json",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 75ea6d4b2f..380db3a3f3 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,12 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Tuple
-
-from twisted.internet import defer
+from typing import List, Optional, Tuple
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 
 # The category ID for the "default" category. We don't store as null in the
@@ -210,9 +209,8 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_rooms_for_summary", _get_rooms_for_summary_txn
         )
 
-    @defer.inlineCallbacks
-    def get_group_categories(self, group_id):
-        rows = yield self.db_pool.simple_select_list(
+    async def get_group_categories(self, group_id):
+        rows = await self.db_pool.simple_select_list(
             table="group_room_categories",
             keyvalues={"group_id": group_id},
             retcols=("category_id", "is_public", "profile"),
@@ -227,9 +225,8 @@ class GroupServerWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    @defer.inlineCallbacks
-    def get_group_category(self, group_id, category_id):
-        category = yield self.db_pool.simple_select_one(
+    async def get_group_category(self, group_id, category_id):
+        category = await self.db_pool.simple_select_one(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             retcols=("is_public", "profile"),
@@ -240,9 +237,8 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         return category
 
-    @defer.inlineCallbacks
-    def get_group_roles(self, group_id):
-        rows = yield self.db_pool.simple_select_list(
+    async def get_group_roles(self, group_id):
+        rows = await self.db_pool.simple_select_list(
             table="group_roles",
             keyvalues={"group_id": group_id},
             retcols=("role_id", "is_public", "profile"),
@@ -257,9 +253,8 @@ class GroupServerWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    @defer.inlineCallbacks
-    def get_group_role(self, group_id, role_id):
-        role = yield self.db_pool.simple_select_one(
+    async def get_group_role(self, group_id, role_id):
+        role = await self.db_pool.simple_select_one(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             retcols=("is_public", "profile"),
@@ -448,12 +443,11 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_attestations_need_renewals", _get_attestations_need_renewals_txn
         )
 
-    @defer.inlineCallbacks
-    def get_remote_attestation(self, group_id, user_id):
+    async def get_remote_attestation(self, group_id, user_id):
         """Get the attestation that proves the remote agrees that the user is
         in the group.
         """
-        row = yield self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcols=("valid_until_ms", "attestation_json"),
@@ -499,13 +493,13 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_all_groups_for_user", _get_all_groups_for_user_txn
         )
 
-    def get_groups_changes_for_user(self, user_id, from_token, to_token):
+    async def get_groups_changes_for_user(self, user_id, from_token, to_token):
         from_token = int(from_token)
         has_changed = self._group_updates_stream_cache.has_entity_changed(
             user_id, from_token
         )
         if not has_changed:
-            return defer.succeed([])
+            return []
 
         def _get_groups_changes_for_user_txn(txn):
             sql = """
@@ -525,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore):
                 for group_id, membership, gtype, content_json in txn
             ]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_groups_changes_for_user", _get_groups_changes_for_user_txn
         )
 
@@ -1087,31 +1081,31 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="update_group_publicity",
         )
 
-    @defer.inlineCallbacks
-    def register_user_group_membership(
+    async def register_user_group_membership(
         self,
-        group_id,
-        user_id,
-        membership,
-        is_admin=False,
-        content={},
-        local_attestation=None,
-        remote_attestation=None,
-        is_publicised=False,
-    ):
+        group_id: str,
+        user_id: str,
+        membership: str,
+        is_admin: bool = False,
+        content: JsonDict = {},
+        local_attestation: Optional[dict] = None,
+        remote_attestation: Optional[dict] = None,
+        is_publicised: bool = False,
+    ) -> int:
         """Registers that a local user is a member of a (local or remote) group.
 
         Args:
-            group_id (str)
-            user_id (str)
-            membership (str)
-            is_admin (bool)
-            content (dict): Content of the membership, e.g. includes the inviter
+            group_id: The group the member is being added to.
+            user_id: THe user ID to add to the group.
+            membership: The type of group membership.
+            is_admin: Whether the user should be added as a group admin.
+            content: Content of the membership, e.g. includes the inviter
                 if the user has been invited.
-            local_attestation (dict): If remote group then store the fact that we
+            local_attestation: If remote group then store the fact that we
                 have given out an attestation, else None.
-            remote_attestation (dict): If remote group then store the remote
+            remote_attestation: If remote group then store the remote
                 attestation from the group, else None.
+            is_publicised: Whether this should be publicised.
         """
 
         def _register_user_group_membership_txn(txn, next_id):
@@ -1188,18 +1182,17 @@ class GroupServerStore(GroupServerWorkerStore):
             return next_id
 
         with self._group_updates_id_gen.get_next() as next_id:
-            res = yield self.db_pool.runInteraction(
+            res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
                 next_id,
             )
         return res
 
-    @defer.inlineCallbacks
-    def create_group(
+    async def create_group(
         self, group_id, user_id, name, avatar_url, short_description, long_description
-    ):
-        yield self.db_pool.simple_insert(
+    ) -> None:
+        await self.db_pool.simple_insert(
             table="groups",
             values={
                 "group_id": group_id,
@@ -1212,9 +1205,8 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="create_group",
         )
 
-    @defer.inlineCallbacks
-    def update_group_profile(self, group_id, profile):
-        yield self.db_pool.simple_update_one(
+    async def update_group_profile(self, group_id, profile):
+        await self.db_pool.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
             updatevalues=profile,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 99e66dc6e9..59ba12820a 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -15,8 +15,6 @@
 
 from typing import List, Tuple
 
-from twisted.internet import defer
-
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.presence import UserPresenceState
 from synapse.util.caches.descriptors import cached, cachedList
@@ -24,14 +22,13 @@ from synapse.util.iterutils import batch_iter
 
 
 class PresenceStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def update_presence(self, presence_states):
+    async def update_presence(self, presence_states):
         stream_ordering_manager = self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
         with stream_ordering_manager as stream_orderings:
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "update_presence",
                 self._update_presence_txn,
                 stream_orderings,
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 4a4f2cb385..b8261357d4 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -13,18 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.databases.main.roommember import ProfileInfo
 
 
 class ProfileWorkerStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def get_profileinfo(self, user_localpart):
+    async def get_profileinfo(self, user_localpart):
         try:
-            profile = yield self.db_pool.simple_select_one(
+            profile = await self.db_pool.simple_select_one(
                 table="profiles",
                 keyvalues={"user_id": user_localpart},
                 retcols=("displayname", "avatar_url"),
@@ -118,14 +115,13 @@ class ProfileStore(ProfileWorkerStore):
             desc="update_remote_profile_cache",
         )
 
-    @defer.inlineCallbacks
-    def maybe_delete_remote_profile_cache(self, user_id):
+    async def maybe_delete_remote_profile_cache(self, user_id):
         """Check if we still care about the remote user's profile, and if we
         don't then remove their profile from the cache
         """
-        subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
+        subscribed = await self.is_subscribed_remote_profile_for_user(user_id)
         if not subscribed:
-            yield self.db_pool.simple_delete(
+            await self.db_pool.simple_delete(
                 table="remote_profile_cache",
                 keyvalues={"user_id": user_id},
                 desc="delete_remote_profile_cache",
@@ -151,11 +147,10 @@ class ProfileStore(ProfileWorkerStore):
             _get_remote_profile_cache_entries_that_expire_txn,
         )
 
-    @defer.inlineCallbacks
-    def is_subscribed_remote_profile_for_user(self, user_id):
+    async def is_subscribed_remote_profile_for_user(self, user_id):
         """Check whether we are interested in a remote user's profile.
         """
-        res = yield self.db_pool.simple_select_one_onecol(
+        res = await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"user_id": user_id},
             retcol="user_id",
@@ -166,7 +161,7 @@ class ProfileStore(ProfileWorkerStore):
         if res:
             return True
 
-        res = yield self.db_pool.simple_select_one_onecol(
+        res = await self.db_pool.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"user_id": user_id},
             retcol="user_id",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b81f1449b7..a9ceffc20e 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,10 +14,12 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 import attr
 
 from synapse.api.constants import RelationTypes
+from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
@@ -25,7 +27,7 @@ from synapse.storage.relations import (
     PaginationChunk,
     RelationPaginationToken,
 )
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
@@ -227,18 +229,18 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
-    @cachedInlineCallbacks()
-    def get_applicable_edit(self, event_id):
+    @cached()
+    async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
         """Get the most recent edit (if any) that has happened for the given
         event.
 
         Correctly handles checking whether edits were allowed to happen.
 
         Args:
-            event_id (str): The original event ID
+            event_id: The original event ID
 
         Returns:
-            Deferred[EventBase|None]: Returns the most recent edit, if any.
+            The most recent edit, if any.
         """
 
         # We only allow edits for `m.room.message` events that have the same sender
@@ -268,15 +270,14 @@ class RelationsWorkerStore(SQLBaseStore):
             if row:
                 return row[0]
 
-        edit_id = yield self.db_pool.runInteraction(
+        edit_id = await self.db_pool.runInteraction(
             "get_applicable_edit", _get_applicable_edit_txn
         )
 
         if not edit_id:
-            return
+            return None
 
-        edit_event = yield self.get_event(edit_id, allow_none=True)
-        return edit_event
+        return await self.get_event(edit_id, allow_none=True)
 
     def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
         """Check if a user has already annotated an event with the same key
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 8804c0e4ac..52668dbdf9 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -18,8 +18,6 @@ from collections import namedtuple
 
 from canonicaljson import encode_canonical_json
 
-from twisted.internet import defer
-
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
@@ -126,8 +124,7 @@ class TransactionStore(SQLBaseStore):
             desc="set_received_txn_response",
         )
 
-    @defer.inlineCallbacks
-    def get_destination_retry_timings(self, destination):
+    async def get_destination_retry_timings(self, destination):
         """Gets the current retry timings (if any) for a given destination.
 
         Args:
@@ -142,7 +139,7 @@ class TransactionStore(SQLBaseStore):
         if result is not SENTINEL:
             return result
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings,
             destination,