summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8100.misc1
-rw-r--r--synapse/storage/database.py23
-rw-r--r--synapse/storage/databases/main/appservice.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py16
-rw-r--r--synapse/storage/databases/main/registration.py8
-rw-r--r--synapse/storage/databases/main/roommember.py4
-rw-r--r--tests/handlers/test_profile.py4
-rw-r--r--tests/handlers/test_typing.py2
-rw-r--r--tests/storage/test_appservice.py16
-rw-r--r--tests/storage/test_base.py16
-rw-r--r--tests/storage/test_event_push_actions.py30
-rw-r--r--tests/storage/test_main.py2
-rw-r--r--tests/storage/test_profile.py4
13 files changed, 69 insertions, 59 deletions
diff --git a/changelog.d/8100.misc b/changelog.d/8100.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8100.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4ada6f5563..8a9e06efcf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -332,8 +332,7 @@ class DatabasePool(object):
         """
         return self._db_pool.running
 
-    @defer.inlineCallbacks
-    def _check_safe_to_upsert(self):
+    async def _check_safe_to_upsert(self):
         """
         Is it safe to use native UPSERT?
 
@@ -342,7 +341,7 @@ class DatabasePool(object):
 
         If the background updates have not completed, wait 15 sec and check again.
         """
-        updates = yield self.simple_select_list(
+        updates = await self.simple_select_list(
             "background_updates",
             keyvalues=None,
             retcols=["update_name"],
@@ -614,8 +613,7 @@ class DatabasePool(object):
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.
 
-    @defer.inlineCallbacks
-    def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+    async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
         """Executes an INSERT query on the named table.
 
         Args:
@@ -631,7 +629,7 @@ class DatabasePool(object):
             `or_ignore` is True
         """
         try:
-            yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+            await self.runInteraction(desc, self.simple_insert_txn, table, values)
         except self.engine.module.IntegrityError:
             # We have to do or_ignore flag at this layer, since we can't reuse
             # a cursor after we receive an error from the db.
@@ -684,8 +682,7 @@ class DatabasePool(object):
 
         txn.executemany(sql, vals)
 
-    @defer.inlineCallbacks
-    def simple_upsert(
+    async def simple_upsert(
         self,
         table,
         keyvalues,
@@ -714,14 +711,14 @@ class DatabasePool(object):
                 inserting
             lock (bool): True to lock the table when doing the upsert.
         Returns:
-            Deferred(None or bool): Native upserts always return None. Emulated
+            None or bool: Native upserts always return None. Emulated
             upserts return True if a new entry was created, False if an existing
             one was updated.
         """
         attempts = 0
         while True:
             try:
-                result = yield self.runInteraction(
+                return await self.runInteraction(
                     desc,
                     self.simple_upsert_txn,
                     table,
@@ -730,7 +727,6 @@ class DatabasePool(object):
                     insertion_values,
                     lock=lock,
                 )
-                return result
             except self.engine.module.IntegrityError as e:
                 attempts += 1
                 if attempts >= 5:
@@ -1121,8 +1117,7 @@ class DatabasePool(object):
 
         return cls.cursor_to_dict(txn)
 
-    @defer.inlineCallbacks
-    def simple_select_many_batch(
+    async def simple_select_many_batch(
         self,
         table,
         column,
@@ -1156,7 +1151,7 @@ class DatabasePool(object):
             it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
         ]
         for chunk in chunks:
-            rows = yield self.runInteraction(
+            rows = await self.runInteraction(
                 desc,
                 self.simple_select_many_txn,
                 table,
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5cf1a88399..02568a2391 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
             service(ApplicationService): The service whose state to set.
             state(ApplicationServiceState): The connectivity state to apply.
         Returns:
-            A Deferred which resolves when the state was set successfully.
+            An Awaitable which resolves when the state was set successfully.
         """
         return self.db_pool.simple_upsert(
             "application_services_state", {"as_id": service.id}, {"state": state}
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5687448e3d..8c63a0dc4d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -847,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield self.db_pool.simple_select_many_batch(
-            table="events",
-            retcols=("event_id",),
-            column="event_id",
-            iterable=list(event_ids),
-            keyvalues={"outlier": False},
-            desc="have_events_in_timeline",
+        rows = yield defer.ensureDeferred(
+            self.db_pool.simple_select_many_batch(
+                table="events",
+                retcols=("event_id",),
+                column="event_id",
+                iterable=list(event_ids),
+                keyvalues={"outlier": False},
+                desc="have_events_in_timeline",
+            )
         )
 
         return {r["event_id"] for r in rows}
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index de50fa6e94..068ad22b30 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,7 @@
 
 import logging
 import re
-from typing import Dict, List, Optional
-
-from twisted.internet.defer import Deferred
+from typing import Awaitable, Dict, List, Optional
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             id_server (str)
 
         Returns:
-            Deferred
+            Awaitable
         """
         # We need to use an upsert, in case they user had already bound the
         # threepid
@@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
     def record_user_external_id(
         self, auth_provider: str, external_id: str, user_id: str
-    ) -> Deferred:
+    ) -> Awaitable:
         """Record a mapping from an external user id to a mxid
 
         Args:
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 1cc8c08ed0..161edbeccb 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return set(room_ids)
 
-    def get_membership_from_event_ids(
+    async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
     ) -> List[dict]:
         """Get user_id and membership of a set of event IDs.
         """
 
-        return self.db_pool.simple_select_many_batch(
+        return await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d70e1fc608..b609b30d4a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
 
-        yield self.store.create_profile(self.frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
 
         self.handler = hs.get_profile_handler()
         self.hs = hs
@@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield self.store.create_profile("caroline")
+        yield defer.ensureDeferred(self.store.create_profile("caroline"))
         yield self.store.set_profile_displayname("caroline", "Caroline")
 
         response = yield defer.ensureDeferred(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 64afd581bc..e01de158e5 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ([], 0)
         )
         self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
-        self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+        self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
             None
         )
 
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 98b74890d5..a425e66f37 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_set_appservices_state_down(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_set_appservices_state_multiple_up(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index efcaeef1e7..13bcac743a 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db_pool.simple_insert(
-            table="tablename", values={"columname": "Value"}
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert(
+                table="tablename", values={"columname": "Value"}
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_3cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db_pool.simple_insert(
-            table="tablename",
-            # Use OrderedDict() so we can assert on the SQL generated
-            values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert(
+                table="tablename",
+                # Use OrderedDict() so we can assert on the SQL generated
+                values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 857db071d4..238bad5b45 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return self.store.db_pool.simple_insert(
-                "events",
-                {
-                    "stream_ordering": so,
-                    "received_ts": ts,
-                    "event_id": "event%i" % so,
-                    "type": "",
-                    "room_id": "",
-                    "content": "",
-                    "processed": True,
-                    "outlier": False,
-                    "topological_ordering": 0,
-                    "depth": 0,
-                },
+            return defer.ensureDeferred(
+                self.store.db_pool.simple_insert(
+                    "events",
+                    {
+                        "stream_ordering": so,
+                        "received_ts": ts,
+                        "event_id": "event%i" % so,
+                        "type": "",
+                        "room_id": "",
+                        "content": "",
+                        "processed": True,
+                        "outlier": False,
+                        "topological_ordering": 0,
+                        "depth": 0,
+                    },
+                )
             )
 
         # start with the base case where there are no events in the table
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..fbf8af940a 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_users_paginate(self):
         yield self.store.register_user(self.user.to_string(), "pass")
-        yield self.store.create_profile(self.user.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
         yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
 
         users, total = yield self.store.get_users_paginate(
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..9d5b8aa47d 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_displayname(self):
-        yield self.store.create_profile(self.u_frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
 
         yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
 
@@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield self.store.create_profile(self.u_frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
 
         yield self.store.set_profile_avatar_url(
             self.u_frank.localpart, "http://my.site/here"