summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-14 07:24:26 -0400
committerGitHub <noreply@github.com>2020-08-14 07:24:26 -0400
commit894dae74fe8e79911c3c001c8b84620ef3985bf6 (patch)
tree3ba58a280c6db5eb89e116ef528061592ce8f987 /synapse
parentRemove a space at the start of a changelog entry. (diff)
downloadsynapse-894dae74fe8e79911c3c001c8b84620ef3985bf6.tar.xz
Convert misc database code to async (#8087)
Diffstat (limited to '')
-rw-r--r--synapse/storage/background_updates.py14
-rw-r--r--synapse/storage/databases/main/devices.py5
-rw-r--r--synapse/storage/databases/main/event_push_actions.py9
-rw-r--r--synapse/storage/databases/main/presence.py9
-rw-r--r--synapse/storage/databases/main/push_rule.py16
-rw-r--r--synapse/storage/databases/main/pusher.py9
-rw-r--r--synapse/storage/databases/main/receipts.py5
-rw-r--r--synapse/storage/databases/main/roommember.py17
-rw-r--r--synapse/storage/databases/main/state.py5
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py13
10 files changed, 38 insertions, 64 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index f43463df53..90a1f9e8b1 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -18,8 +18,6 @@ from typing import Optional
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.metrics.background_process_metrics import run_as_background_process
 
 from . import engines
@@ -308,9 +306,8 @@ class BackgroundUpdater(object):
             update_name (str): Name of update
         """
 
-        @defer.inlineCallbacks
-        def noop_update(progress, batch_size):
-            yield self._end_background_update(update_name)
+        async def noop_update(progress, batch_size):
+            await self._end_background_update(update_name)
             return 1
 
         self.register_background_update_handler(update_name, noop_update)
@@ -409,12 +406,11 @@ class BackgroundUpdater(object):
         else:
             runner = create_index_sqlite
 
-        @defer.inlineCallbacks
-        def updater(progress, batch_size):
+        async def updater(progress, batch_size):
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
-                yield self.db_pool.runWithConnection(runner)
-            yield self._end_background_update(update_name)
+                await self.db_pool.runWithConnection(runner)
+            await self._end_background_update(update_name)
             return 1
 
         self.register_background_update_handler(update_name, updater)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2b33060480..9a786e2929 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
     @cachedList(
         cached_method_name="get_device_list_last_stream_id_for_remote",
         list_name="user_ids",
-        inlineCallbacks=True,
     )
-    def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+        rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
             iterable=user_ids,
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7c246d3e4c..e8834b2162 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self._rotate_delay = 3
         self._rotate_count = 10000
 
-    @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
-    def get_unread_event_push_actions_by_room_for_user(
+    @cached(num_args=3, tree=True, max_entries=5000)
+    async def get_unread_event_push_actions_by_room_for_user(
         self, room_id, user_id, last_read_event_id
     ):
-        ret = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
             room_id,
             user_id,
             last_read_event_id,
         )
-        return ret
 
     def _get_unread_counts_by_receipt_txn(
         self, txn, room_id, user_id, last_read_event_id
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 59ba12820a..fd213d2dfd 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_presence_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
     )
-    def get_presence_for_users(self, user_ids):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_presence_for_users(self, user_ids):
+        rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
             iterable=user_ids,
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 6562db5c2b..6aa5802977 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -170,18 +170,15 @@ class PushRulesWorkerStore(
             )
 
     @cachedList(
-        cached_method_name="get_push_rules_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
     )
-    def bulk_get_push_rules(self, user_ids):
+    async def bulk_get_push_rules(self, user_ids):
         if not user_ids:
             return {}
 
         results = {user_id: [] for user_id in user_ids}
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
             column="user_name",
             iterable=user_ids,
@@ -194,7 +191,7 @@ class PushRulesWorkerStore(
         for row in rows:
             results.setdefault(row["user_name"], []).append(row)
 
-        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
         for user_id, rules in results.items():
             use_new_defaults = user_id in self._users_new_default_push_rules
@@ -260,15 +257,14 @@ class PushRulesWorkerStore(
         cached_method_name="get_push_rules_enabled_for_user",
         list_name="user_ids",
         num_args=1,
-        inlineCallbacks=True,
     )
-    def bulk_get_push_rules_enabled(self, user_ids):
+    async def bulk_get_push_rules_enabled(self, user_ids):
         if not user_ids:
             return {}
 
         results = {user_id: {} for user_id in user_ids}
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="push_rules_enable",
             column="user_name",
             iterable=user_ids,
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b5200fbe79..8b793d1487 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -170,13 +170,10 @@ class PusherWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="get_if_user_has_pusher",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
     )
-    def get_if_users_have_pushers(self, user_ids):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_if_users_have_pushers(self, user_ids):
+        rows = await self.db_pool.simple_select_many_batch(
             table="pushers",
             column="user_name",
             iterable=user_ids,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1920a8a152..579b7bb17b 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -212,9 +212,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         cached_method_name="_get_linearized_receipts_for_room",
         list_name="room_ids",
         num_args=3,
-        inlineCallbacks=True,
     )
-    def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         if not room_ids:
             return {}
 
@@ -243,7 +242,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return self.db_pool.cursor_to_dict(txn)
 
-        txn_results = yield self.db_pool.runInteraction(
+        txn_results = await self.db_pool.runInteraction(
             "_get_linearized_receipts_for_rooms", f
         )
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index b2fcfc9bfe..1cc8c08ed0 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -17,8 +17,6 @@
 import logging
 from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 lambda: self._known_servers_count,
             )
 
-    @defer.inlineCallbacks
-    def _count_known_servers(self):
+    async def _count_known_servers(self):
         """
         Count the servers that this server knows about.
 
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(query)
             return list(txn)[0][0]
 
-        count = yield self.db_pool.runInteraction("get_known_servers", _transact)
+        count = await self.db_pool.runInteraction("get_known_servers", _transact)
 
         # We always know about ourselves, even if we have nothing in
         # room_memberships (for example, the server is new).
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
-        list_name="event_ids",
-        inlineCallbacks=True,
+        cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
     )
-    def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+    async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event_ids: The member event IDs to lookup
 
         Returns:
-            Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+            dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
             to `user_id` and ProfileInfo (or None if not join event).
         """
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 96e0378e50..991233a9bc 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         cached_method_name="_get_state_group_for_event",
         list_name="event_ids",
         num_args=1,
-        inlineCallbacks=True,
     )
-    def _get_state_group_for_events(self, event_ids):
+    async def _get_state_group_for_events(self, event_ids):
         """Returns mapping event_id -> state_group
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
             column="event_id",
             iterable=event_ids,
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index ab6cb2c1f6..da23fe7355 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore):
             desc="is_user_erased",
         ).addCallback(operator.truth)
 
-    @cachedList(
-        cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
-    )
-    def are_users_erased(self, user_ids):
+    @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
+    async def are_users_erased(self, user_ids):
         """
         Checks which users in a list have requested erasure
 
@@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore):
             user_ids (iterable[str]): full user id to check
 
         Returns:
-            Deferred[dict[str, bool]]:
+            dict[str, bool]:
                 for each user, whether the user has requested erasure.
         """
         # this serves the dual purpose of (a) making sure we can do len and
         # iterate it multiple times, and (b) avoiding duplicates.
         user_ids = tuple(set(user_ids))
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="erased_users",
             column="user_id",
             iterable=user_ids,
@@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore):
         )
         erased_users = {row["user_id"] for row in rows}
 
-        res = {u: u in erased_users for u in user_ids}
-        return res
+        return {u: u in erased_users for u in user_ids}
 
 
 class UserErasureStore(UserErasureWorkerStore):