summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-12 09:29:06 -0400
committerGitHub <noreply@github.com>2020-08-12 09:29:06 -0400
commitd68e10f308f89810e8d9ff94219cc68ca83f636d (patch)
tree6364d2a49d9f604c25142729514e95148ff9be90
parentConvert appservice, group server, profile and more databases to async (#8066) (diff)
downloadsynapse-d68e10f308f89810e8d9ff94219cc68ca83f636d.tar.xz
Convert account data, device inbox, and censor events databases to async/await (#8063)
-rw-r--r--changelog.d/8063.misc1
-rw-r--r--synapse/storage/databases/main/account_data.py77
-rw-r--r--synapse/storage/databases/main/censor_events.py11
-rw-r--r--synapse/storage/databases/main/deviceinbox.py94
-rw-r--r--tests/handlers/test_typing.py3
5 files changed, 99 insertions, 87 deletions
diff --git a/changelog.d/8063.misc b/changelog.d/8063.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8063.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index cf039e7f7d..82aac2bbf3 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,15 +16,16 @@
 
 import abc
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
@@ -97,13 +98,15 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
-    @cachedInlineCallbacks(num_args=2, max_entries=5000)
-    def get_global_account_data_by_type_for_user(self, data_type, user_id):
+    @cached(num_args=2, max_entries=5000)
+    async def get_global_account_data_by_type_for_user(
+        self, data_type: str, user_id: str
+    ) -> Optional[JsonDict]:
         """
         Returns:
-            Deferred: A dict
+            The account data.
         """
-        result = yield self.db_pool.simple_select_one_onecol(
+        result = await self.db_pool.simple_select_one_onecol(
             table="account_data",
             keyvalues={"user_id": user_id, "account_data_type": data_type},
             retcol="content",
@@ -280,9 +283,11 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
-    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
-    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
-        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+    @cached(num_args=2, cache_context=True, max_entries=5000)
+    async def is_ignored_by(
+        self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
+    ) -> bool:
+        ignored_account_data = await self.get_global_account_data_by_type_for_user(
             "m.ignored_user_list",
             ignorer_user_id,
             on_invalidate=cache_context.invalidate,
@@ -307,24 +312,27 @@ class AccountDataStore(AccountDataWorkerStore):
 
         super(AccountDataStore, self).__init__(database, db_conn, hs)
 
-    def get_max_account_data_stream_id(self):
+    def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream id for the private user data stream
 
         Returns:
-            A deferred int.
+            The maximum stream ID.
         """
         return self._account_data_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
+    async def add_account_data_to_room(
+        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
         """Add some account_data to a room for a user.
+
         Args:
-            user_id(str): The user to add a tag for.
-            room_id(str): The room to add a tag for.
-            account_data_type(str): The type of account_data to add.
-            content(dict): A json object to associate with the tag.
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
         Returns:
-            A deferred that completes once the account_data has been added.
+            The maximum stream ID.
         """
         content_json = json_encoder.encode(content)
 
@@ -332,7 +340,7 @@ class AccountDataStore(AccountDataWorkerStore):
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
-            yield self.db_pool.simple_upsert(
+            await self.db_pool.simple_upsert(
                 desc="add_room_account_data",
                 table="room_account_data",
                 keyvalues={
@@ -350,7 +358,7 @@ class AccountDataStore(AccountDataWorkerStore):
             # doesn't sound any worse than the whole update getting lost,
             # which is what would happen if we combined the two into one
             # transaction.
-            yield self._update_max_stream_id(next_id)
+            await self._update_max_stream_id(next_id)
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
@@ -359,18 +367,20 @@ class AccountDataStore(AccountDataWorkerStore):
                 (user_id, room_id, account_data_type), content
             )
 
-        result = self._account_data_id_gen.get_current_token()
-        return result
+        return self._account_data_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def add_account_data_for_user(self, user_id, account_data_type, content):
+    async def add_account_data_for_user(
+        self, user_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
         """Add some account_data to a room for a user.
+
         Args:
-            user_id(str): The user to add a tag for.
-            account_data_type(str): The type of account_data to add.
-            content(dict): A json object to associate with the tag.
+            user_id: The user to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
         Returns:
-            A deferred that completes once the account_data has been added.
+            The maximum stream ID.
         """
         content_json = json_encoder.encode(content)
 
@@ -378,7 +388,7 @@ class AccountDataStore(AccountDataWorkerStore):
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
-            yield self.db_pool.simple_upsert(
+            await self.db_pool.simple_upsert(
                 desc="add_user_account_data",
                 table="account_data",
                 keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -396,7 +406,7 @@ class AccountDataStore(AccountDataWorkerStore):
             # Note: This is only here for backwards compat to allow admins to
             # roll back to a previous Synapse version. Next time we update the
             # database version we can remove this table.
-            yield self._update_max_stream_id(next_id)
+            await self._update_max_stream_id(next_id)
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
@@ -404,14 +414,13 @@ class AccountDataStore(AccountDataWorkerStore):
                 (account_data_type, user_id)
             )
 
-        result = self._account_data_id_gen.get_current_token()
-        return result
+        return self._account_data_id_gen.get_current_token()
 
-    def _update_max_stream_id(self, next_id):
+    def _update_max_stream_id(self, next_id: int):
         """Update the max stream_id
 
         Args:
-            next_id(int): The the revision to advance to.
+            next_id: The the revision to advance to.
         """
 
         # Note: This is only here for backwards compat to allow admins to
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 1de8249563..f211ddbaf8 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -16,8 +16,6 @@
 import logging
 from typing import TYPE_CHECKING
 
-from twisted.internet import defer
-
 from synapse.events.utils import prune_event_dict
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
@@ -148,17 +146,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             updatevalues={"json": pruned_json},
         )
 
-    @defer.inlineCallbacks
-    def expire_event(self, event_id):
+    async def expire_event(self, event_id: str) -> None:
         """Retrieve and expire an event that has expired, and delete its associated
         expiry timestamp. If the event can't be retrieved, delete its associated
         timestamp so we don't try to expire it again in the future.
 
         Args:
-             event_id (str): The ID of the event to delete.
+             event_id: The ID of the event to delete.
         """
         # Try to retrieve the event's content from the database or the event cache.
-        event = yield self.get_event(event_id)
+        event = await self.get_event(event_id)
 
         def delete_expired_event_txn(txn):
             # Delete the expiry timestamp associated with this event from the database.
@@ -193,7 +190,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 txn, "_get_event_cache", (event.event_id,)
             )
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_expired_event", delete_expired_event_txn
         )
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 76ec954f44..1f6e995c4f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -16,8 +16,6 @@
 import logging
 from typing import List, Tuple
 
-from twisted.internet import defer
-
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
@@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
-    def get_new_messages_for_device(
-        self, user_id, device_id, last_stream_id, current_stream_id, limit=100
-    ):
+    async def get_new_messages_for_device(
+        self,
+        user_id: str,
+        device_id: str,
+        last_stream_id: int,
+        current_stream_id: int,
+        limit: int = 100,
+    ) -> Tuple[List[dict], int]:
         """
         Args:
-            user_id(str): The recipient user_id.
-            device_id(str): The recipient device_id.
-            current_stream_id(int): The current position of the to device
+            user_id: The recipient user_id.
+            device_id: The recipient device_id.
+            last_stream_id: The last stream ID checked.
+            current_stream_id: The current position of the to device
                 message stream.
+            limit: The maximum number of messages to retrieve.
+
         Returns:
-            Deferred ([dict], int): List of messages for the device and where
-                in the stream the messages got to.
+            A list of messages for the device and where in the stream the messages got to.
         """
         has_changed = self._device_inbox_stream_cache.has_entity_changed(
             user_id, last_stream_id
         )
         if not has_changed:
-            return defer.succeed(([], current_stream_id))
+            return ([], current_stream_id)
 
         def get_new_messages_for_device_txn(txn):
             sql = (
@@ -69,20 +74,22 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_new_messages_for_device", get_new_messages_for_device_txn
         )
 
     @trace
-    @defer.inlineCallbacks
-    def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+    async def delete_messages_for_device(
+        self, user_id: str, device_id: str, up_to_stream_id: int
+    ) -> int:
         """
         Args:
-            user_id(str): The recipient user_id.
-            device_id(str): The recipient device_id.
-            up_to_stream_id(int): Where to delete messages up to.
+            user_id: The recipient user_id.
+            device_id: The recipient device_id.
+            up_to_stream_id: Where to delete messages up to.
+
         Returns:
-            A deferred that resolves to the number of messages deleted.
+            The number of messages deleted.
         """
         # If we have cached the last stream id we've deleted up to, we can
         # check if there is likely to be anything that needs deleting
@@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
-        count = yield self.db_pool.runInteraction(
+        count = await self.db_pool.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
@@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return count
 
     @trace
-    def get_new_device_msgs_for_remote(
+    async def get_new_device_msgs_for_remote(
         self, destination, last_stream_id, current_stream_id, limit
-    ):
+    ) -> Tuple[List[dict], int]:
         """
         Args:
             destination(str): The name of the remote server.
@@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             current_stream_id(int|long): The current position of the device
                 message stream.
         Returns:
-            Deferred ([dict], int|long): List of messages for the device and where
-                in the stream the messages got to.
+            A list of messages for the device and where in the stream the messages got to.
         """
 
         set_tag("destination", destination)
@@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         )
         if not has_changed or last_stream_id == current_stream_id:
             log_kv({"message": "No new messages in stream"})
-            return defer.succeed(([], current_stream_id))
+            return ([], current_stream_id)
 
         if limit <= 0:
             # This can happen if we run out of room for EDUs in the transaction.
-            return defer.succeed(([], last_stream_id))
+            return ([], last_stream_id)
 
         @trace
         def get_new_messages_for_remote_destination_txn(txn):
@@ -178,7 +184,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_new_device_msgs_for_remote",
             get_new_messages_for_remote_destination_txn,
         )
@@ -290,16 +296,15 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
-    @defer.inlineCallbacks
-    def _background_drop_index_device_inbox(self, progress, batch_size):
+    async def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
             txn.close()
 
-        yield self.db_pool.runWithConnection(reindex_txn)
+        await self.db_pool.runWithConnection(reindex_txn)
 
-        yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
 
         return 1
 
@@ -320,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         )
 
     @trace
-    @defer.inlineCallbacks
-    def add_messages_to_device_inbox(
-        self, local_messages_by_user_then_device, remote_messages_by_destination
-    ):
+    async def add_messages_to_device_inbox(
+        self,
+        local_messages_by_user_then_device: dict,
+        remote_messages_by_destination: dict,
+    ) -> int:
         """Used to send messages from this server.
 
         Args:
-            sender_user_id(str): The ID of the user sending these messages.
-            local_messages_by_user_and_device(dict):
+            local_messages_by_user_and_device:
                 Dictionary of user_id to device_id to message.
-            remote_messages_by_destination(dict):
+            remote_messages_by_destination:
                 Dictionary of destination server_name to the EDU JSON to send.
+
         Returns:
-            A deferred stream_id that resolves when the messages have been
-            inserted.
+            The new stream_id.
         """
 
         def add_messages_txn(txn, now_ms, stream_id):
@@ -359,7 +364,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
             )
             for user_id in local_messages_by_user_then_device.keys():
@@ -371,10 +376,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         return self._device_inbox_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def add_messages_from_remote_to_device_inbox(
-        self, origin, message_id, local_messages_by_user_then_device
-    ):
+    async def add_messages_from_remote_to_device_inbox(
+        self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+    ) -> int:
         def add_messages_txn(txn, now_ms, stream_id):
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
@@ -409,7 +413,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
                 add_messages_txn,
                 now_ms,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index b7d0adb10e..64ddd8243d 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
 from synapse.types import UserID
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 from tests.utils import register_federation_servlets
 
@@ -151,7 +152,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.datastore.get_current_state_deltas.return_value = (0, None)
 
         self.datastore.get_to_device_stream_token = lambda: 0
-        self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
+        self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
             ([], 0)
         )
         self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None